diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/db/db.go | 19 | ||||
| -rw-r--r-- | pkg/db/db_test.go | 55 |
2 files changed, 64 insertions, 10 deletions
diff --git a/pkg/db/db.go b/pkg/db/db.go index b62882b60..c2082c0a2 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -37,6 +37,9 @@ type Record struct { Seq uint64 } +// Open opens the specified database file. +// If the database is corrupted and reading failed, then it returns an non-nil db +// with whatever records were recovered and a non-nil error at the same time. func Open(filename string) (*DB, error) { db := &DB{ filename: filename, @@ -46,17 +49,15 @@ func Open(filename string) (*DB, error) { return nil, err } defer f.Close() - db.Version, db.Records, db.uncompacted, err = deserializeDB(bufio.NewReader(f)) - if err != nil { + // Deserialization error is considered a "soft" error, + // but compact below ensures that the file is at least writable. + var deserializeErr error + db.Version, db.Records, db.uncompacted, deserializeErr = deserializeDB(bufio.NewReader(f)) + f.Close() // compact will rewrite the file, so close our descriptor + if err := db.compact(); err != nil { return nil, err } - if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) { - f.Close() // compact will rewrite the file, so close our descriptor - if err := db.compact(); err != nil { - return nil, err - } - } - return db, nil + return db, deserializeErr } func (db *DB) Save(key string, val []byte, seq uint64) { diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index 71b93fd64..7ff69b2b7 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -130,8 +130,61 @@ func TestOpenInvalid(t *testing.T) { if _, err := f.Write([]byte(`some invalid data`)); err != nil { t.Error(err) } - if _, err := Open(f.Name()); err == nil { + if db, err := Open(f.Name()); err == nil { t.Fatal("opened invalid db") + } else if db == nil { + t.Fatal("db is nil") + } +} + +func TestOpenInaccessible(t *testing.T) { + f, err := ioutil.TempFile("", "syz-db-test") + if err != nil { + t.Error(err) + } + f.Close() + os.Chmod(f.Name(), 0) + defer os.Chmod(f.Name(), 0777) + defer os.Remove(f.Name()) + if db, err := Open(f.Name()); err == nil { + t.Fatal("opened inaccessible db") + } else if db != nil { + t.Fatal("db is not nil") + } +} + +func TestOpenCorrupted(t *testing.T) { + fn := tempFile(t) + defer os.Remove(fn) + db, err := Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + // Write 1000 records, then wipe half of the file and test that we + // (1) get an error, (2) still get 450-550 records. + for i := 0; i < 1000; i++ { + db.Save(fmt.Sprintf("%v", i), []byte{byte(i)}, 0) + } + if err := db.Flush(); err != nil { + t.Fatalf("failed to flush db: %v", err) + } + data, err := ioutil.ReadFile(fn) + if err != nil { + t.Fatalf("failed to read db: %v", err) + } + for i := len(data) / 2; i < len(data); i++ { + data[i] = 0 + } + if err := osutil.WriteFile(fn, data); err != nil { + t.Fatalf("failed to write db: %v", err) + } + db, err = Open(fn) + if err == nil { + t.Fatalf("no error for corrutped db") + } + t.Logf("records %v, error: %v", len(db.Records), err) + if len(db.Records) < 450 || len(db.Records) > 550 { + t.Fatalf("wrong record count: %v", len(db.Records)) } } |
