aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/db/db.go19
-rw-r--r--pkg/db/db_test.go55
2 files changed, 64 insertions, 10 deletions
diff --git a/pkg/db/db.go b/pkg/db/db.go
index b62882b60..c2082c0a2 100644
--- a/pkg/db/db.go
+++ b/pkg/db/db.go
@@ -37,6 +37,9 @@ type Record struct {
Seq uint64
}
+// Open opens the specified database file.
+// If the database is corrupted and reading failed, then it returns an non-nil db
+// with whatever records were recovered and a non-nil error at the same time.
func Open(filename string) (*DB, error) {
db := &DB{
filename: filename,
@@ -46,17 +49,15 @@ func Open(filename string) (*DB, error) {
return nil, err
}
defer f.Close()
- db.Version, db.Records, db.uncompacted, err = deserializeDB(bufio.NewReader(f))
- if err != nil {
+ // Deserialization error is considered a "soft" error,
+ // but compact below ensures that the file is at least writable.
+ var deserializeErr error
+ db.Version, db.Records, db.uncompacted, deserializeErr = deserializeDB(bufio.NewReader(f))
+ f.Close() // compact will rewrite the file, so close our descriptor
+ if err := db.compact(); err != nil {
return nil, err
}
- if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) {
- f.Close() // compact will rewrite the file, so close our descriptor
- if err := db.compact(); err != nil {
- return nil, err
- }
- }
- return db, nil
+ return db, deserializeErr
}
func (db *DB) Save(key string, val []byte, seq uint64) {
diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go
index 71b93fd64..7ff69b2b7 100644
--- a/pkg/db/db_test.go
+++ b/pkg/db/db_test.go
@@ -130,8 +130,61 @@ func TestOpenInvalid(t *testing.T) {
if _, err := f.Write([]byte(`some invalid data`)); err != nil {
t.Error(err)
}
- if _, err := Open(f.Name()); err == nil {
+ if db, err := Open(f.Name()); err == nil {
t.Fatal("opened invalid db")
+ } else if db == nil {
+ t.Fatal("db is nil")
+ }
+}
+
+func TestOpenInaccessible(t *testing.T) {
+ f, err := ioutil.TempFile("", "syz-db-test")
+ if err != nil {
+ t.Error(err)
+ }
+ f.Close()
+ os.Chmod(f.Name(), 0)
+ defer os.Chmod(f.Name(), 0777)
+ defer os.Remove(f.Name())
+ if db, err := Open(f.Name()); err == nil {
+ t.Fatal("opened inaccessible db")
+ } else if db != nil {
+ t.Fatal("db is not nil")
+ }
+}
+
+func TestOpenCorrupted(t *testing.T) {
+ fn := tempFile(t)
+ defer os.Remove(fn)
+ db, err := Open(fn)
+ if err != nil {
+ t.Fatalf("failed to open db: %v", err)
+ }
+ // Write 1000 records, then wipe half of the file and test that we
+ // (1) get an error, (2) still get 450-550 records.
+ for i := 0; i < 1000; i++ {
+ db.Save(fmt.Sprintf("%v", i), []byte{byte(i)}, 0)
+ }
+ if err := db.Flush(); err != nil {
+ t.Fatalf("failed to flush db: %v", err)
+ }
+ data, err := ioutil.ReadFile(fn)
+ if err != nil {
+ t.Fatalf("failed to read db: %v", err)
+ }
+ for i := len(data) / 2; i < len(data); i++ {
+ data[i] = 0
+ }
+ if err := osutil.WriteFile(fn, data); err != nil {
+ t.Fatalf("failed to write db: %v", err)
+ }
+ db, err = Open(fn)
+ if err == nil {
+ t.Fatalf("no error for corrutped db")
+ }
+ t.Logf("records %v, error: %v", len(db.Records), err)
+ if len(db.Records) < 450 || len(db.Records) > 550 {
+ t.Fatalf("wrong record count: %v", len(db.Records))
}
}