diff options
Diffstat (limited to 'syz-hub')
| -rw-r--r-- | syz-hub/hub.go | 14 | ||||
| -rw-r--r-- | syz-hub/state/state.go | 122 | ||||
| -rw-r--r-- | syz-hub/state/state_test.go | 237 |
3 files changed, 263 insertions, 110 deletions
diff --git a/syz-hub/hub.go b/syz-hub/hub.go index db8dd506e..fb39c59b0 100644 --- a/syz-hub/hub.go +++ b/syz-hub/hub.go @@ -75,7 +75,7 @@ func (hub *Hub) Connect(a *rpctype.HubConnectArgs, r *int) error { log.Logf(0, "connect from %v: fresh=%v calls=%v corpus=%v", name, a.Fresh, len(a.Calls), len(a.Corpus)) - if err := hub.st.Connect(name, a.Fresh, a.Calls, a.Corpus); err != nil { + if err := hub.st.Connect(name, a.Domain, a.Fresh, a.Calls, a.Corpus); err != nil { log.Logf(0, "connect error: %v", err) return err } @@ -90,12 +90,18 @@ func (hub *Hub) Sync(a *rpctype.HubSyncArgs, r *rpctype.HubSyncRes) error { hub.mu.Lock() defer hub.mu.Unlock() - progs, more, err := hub.st.Sync(name, a.Add, a.Del) + domain, inputs, more, err := hub.st.Sync(name, a.Add, a.Del) if err != nil { log.Logf(0, "sync error: %v", err) return err } - r.Progs = progs + if domain != "" { + r.Inputs = inputs + } else { + for _, inp := range inputs { + r.Progs = append(r.Progs, inp.Prog) + } + } r.More = more for _, repro := range a.Repros { if err := hub.st.AddRepro(name, repro); err != nil { @@ -112,7 +118,7 @@ func (hub *Hub) Sync(a *rpctype.HubSyncArgs, r *rpctype.HubSyncRes) error { } } log.Logf(0, "sync from %v: recv: add=%v del=%v repros=%v; send: progs=%v repros=%v pending=%v", - name, len(a.Add), len(a.Del), len(a.Repros), len(r.Progs), len(r.Repros), more) + name, len(a.Add), len(a.Del), len(a.Repros), len(inputs), len(r.Repros), more) return nil } diff --git a/syz-hub/state/state.go b/syz-hub/state/state.go index dd722d80c..9ffe28cda 100644 --- a/syz-hub/state/state.go +++ b/syz-hub/state/state.go @@ -16,6 +16,7 @@ import ( "github.com/google/syzkaller/pkg/hash" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/osutil" + "github.com/google/syzkaller/pkg/rpctype" "github.com/google/syzkaller/prog" ) @@ -34,11 +35,13 @@ type State struct { // Manager represents one syz-manager instance. type Manager struct { name string + domain string corpusSeq uint64 reproSeq uint64 corpusFile string corpusSeqFile string reproSeqFile string + domainFile string ownRepros map[string]bool Connected time.Time Added int @@ -59,11 +62,11 @@ func Make(dir string) (*State, error) { osutil.MkdirAll(st.dir) var err error - st.Corpus, st.corpusSeq, err = loadDB(filepath.Join(st.dir, "corpus.db"), "corpus") + st.Corpus, st.corpusSeq, err = loadDB(filepath.Join(st.dir, "corpus.db"), "corpus", true) if err != nil { log.Fatal(err) } - st.Repros, st.reproSeq, err = loadDB(filepath.Join(st.dir, "repro.db"), "repro") + st.Repros, st.reproSeq, err = loadDB(filepath.Join(st.dir, "repro.db"), "repro", true) if err != nil { log.Fatal(err) } @@ -83,11 +86,21 @@ func Make(dir string) (*State, error) { log.Logf(0, "purging corpus...") st.purgeCorpus() log.Logf(0, "done, %v programs", len(st.Corpus.Records)) - return st, err } -func loadDB(file, name string) (*db.DB, uint64, error) { +func (st *State) Flush() { + if err := st.Corpus.Flush(); err != nil { + log.Logf(0, "failed to flush corpus database: %v", err) + } + for _, mgr := range st.Managers { + if err := mgr.Corpus.Flush(); err != nil { + log.Logf(0, "failed to flush corpus database: %v", err) + } + } +} + +func loadDB(file, name string, progs bool) (*db.DB, uint64, error) { log.Logf(0, "reading %v...", name) db, err := db.Open(file) if err != nil { @@ -96,21 +109,23 @@ func loadDB(file, name string) (*db.DB, uint64, error) { log.Logf(0, "read %v programs", len(db.Records)) var maxSeq uint64 for key, rec := range db.Records { - _, ncalls, err := prog.CallSet(rec.Val) - if err != nil { - log.Logf(0, "bad file: can't parse call set: %v", err) - db.Delete(key) - continue - } - if ncalls > prog.MaxCalls { - log.Logf(0, "bad file: too many calls: %v", ncalls) - db.Delete(key) - continue - } - if sig := hash.Hash(rec.Val); sig.String() != key { - log.Logf(0, "bad file: hash %v, want hash %v", key, sig.String()) - db.Delete(key) - continue + if progs { + _, ncalls, err := prog.CallSet(rec.Val) + if err != nil { + log.Logf(0, "bad file: can't parse call set: %v\n%q", err, rec.Val) + db.Delete(key) + continue + } + if ncalls > prog.MaxCalls { + log.Logf(0, "bad file: too many calls: %v", ncalls) + db.Delete(key) + continue + } + if sig := hash.Hash(rec.Val); sig.String() != key { + log.Logf(0, "bad file: hash %v, want hash %v", key, sig.String()) + db.Delete(key) + continue + } } if maxSeq < rec.Seq { maxSeq = rec.Seq @@ -130,6 +145,7 @@ func (st *State) createManager(name string) (*Manager, error) { corpusFile: filepath.Join(dir, "corpus.db"), corpusSeqFile: filepath.Join(dir, "seq"), reproSeqFile: filepath.Join(dir, "repro.seq"), + domainFile: filepath.Join(dir, "domain"), ownRepros: make(map[string]bool), } mgr.corpusSeq = loadSeqFile(mgr.corpusSeqFile) @@ -143,18 +159,20 @@ func (st *State) createManager(name string) (*Manager, error) { if st.reproSeq < mgr.reproSeq { st.reproSeq = mgr.reproSeq } - corpus, _, err := loadDB(mgr.corpusFile, name) + domainData, _ := ioutil.ReadFile(mgr.domainFile) + mgr.domain = string(domainData) + corpus, _, err := loadDB(mgr.corpusFile, name, false) if err != nil { return nil, fmt.Errorf("failed to open manager corpus %v: %v", mgr.corpusFile, err) } mgr.Corpus = corpus - log.Logf(0, "created manager %v: corpus=%v, corpusSeq=%v, reproSeq=%v", - mgr.name, len(mgr.Corpus.Records), mgr.corpusSeq, mgr.reproSeq) + log.Logf(0, "created manager %v: domain=%v corpus=%v, corpusSeq=%v, reproSeq=%v", + mgr.name, mgr.domain, len(mgr.Corpus.Records), mgr.corpusSeq, mgr.reproSeq) st.Managers[name] = mgr return mgr, nil } -func (st *State) Connect(name string, fresh bool, calls []string, corpus [][]byte) error { +func (st *State) Connect(name, domain string, fresh bool, calls []string, corpus [][]byte) error { mgr := st.Managers[name] if mgr == nil { var err error @@ -164,6 +182,8 @@ func (st *State) Connect(name string, fresh bool, calls []string, corpus [][]byt } } mgr.Connected = time.Now() + mgr.domain = domain + writeFile(mgr.domainFile, []byte(mgr.domain)) if fresh { mgr.corpusSeq = 0 mgr.reproSeq = st.reproSeq @@ -188,10 +208,10 @@ func (st *State) Connect(name string, fresh bool, calls []string, corpus [][]byt return nil } -func (st *State) Sync(name string, add [][]byte, del []string) ([][]byte, int, error) { +func (st *State) Sync(name string, add [][]byte, del []string) (string, []rpctype.HubInput, int, error) { mgr := st.Managers[name] if mgr == nil || mgr.Connected.IsZero() { - return nil, 0, fmt.Errorf("unconnected manager %v", name) + return "", nil, 0, fmt.Errorf("unconnected manager %v", name) } if len(del) != 0 { for _, sig := range del { @@ -207,7 +227,7 @@ func (st *State) Sync(name string, add [][]byte, del []string) ([][]byte, int, e mgr.Added += len(add) mgr.Deleted += len(del) mgr.New += len(progs) - return progs, more, err + return mgr.domain, progs, more, err } func (st *State) AddRepro(name string, repro []byte) error { @@ -278,11 +298,16 @@ func (st *State) PendingRepro(name string) ([]byte, error) { return repro, nil } -func (st *State) pendingInputs(mgr *Manager) ([][]byte, int, error) { +func (st *State) pendingInputs(mgr *Manager) ([]rpctype.HubInput, int, error) { if mgr.corpusSeq == st.corpusSeq { return nil, 0, nil } - var records []db.Record + type Record struct { + Key string + Val []byte + Seq uint64 + } + var records []Record for key, rec := range st.Corpus.Records { if mgr.corpusSeq >= rec.Seq { continue @@ -297,7 +322,7 @@ func (st *State) pendingInputs(mgr *Manager) ([][]byte, int, error) { if !managerSupportsAllCalls(mgr.Calls, calls) { continue } - records = append(records, rec) + records = append(records, Record{key, rec.Val, rec.Seq}) } maxSeq := st.corpusSeq more := 0 @@ -325,15 +350,50 @@ func (st *State) pendingInputs(mgr *Manager) ([][]byte, int, error) { more = len(records) - pos records = records[:pos] } - progs := make([][]byte, len(records)) + progs := make([]rpctype.HubInput, 0, len(records)) for _, rec := range records { - progs = append(progs, rec.Val) + domain := "" + for _, mgr1 := range st.Managers { + same := mgr1.domain == mgr.domain + if !same && domain != "" { + continue + } + if _, ok := mgr1.Corpus.Records[rec.Key]; !ok { + continue + } + domain = mgr1.domain + if same { + break + } + } + progs = append(progs, rpctype.HubInput{ + Domain: st.inputDomain(rec.Key, mgr.domain), + Prog: rec.Val, + }) } mgr.corpusSeq = maxSeq saveSeqFile(mgr.corpusSeqFile, mgr.corpusSeq) return progs, more, nil } +func (st *State) inputDomain(key, self string) string { + domain := "" + for _, mgr := range st.Managers { + same := mgr.domain == self + if !same && domain != "" { + continue + } + if _, ok := mgr.Corpus.Records[key]; !ok { + continue + } + domain = mgr.domain + if same { + break + } + } + return domain +} + func (st *State) addInputs(mgr *Manager, inputs [][]byte) { if len(inputs) == 0 { return diff --git a/syz-hub/state/state_test.go b/syz-hub/state/state_test.go index 0972ad6bc..8db7a30f6 100644 --- a/syz-hub/state/state_test.go +++ b/syz-hub/state/state_test.go @@ -4,111 +4,198 @@ package state import ( - "fmt" "io/ioutil" "os" - "path/filepath" - "runtime" + "sort" "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/syzkaller/pkg/rpctype" ) -func TestState(t *testing.T) { +type TestState struct { + t *testing.T + dir string + state *State +} + +func MakeTestState(t *testing.T) *TestState { + t.Parallel() dir, err := ioutil.TempDir("", "syz-hub-state-test") if err != nil { t.Fatalf("failed to create temp dir: %v", err) } - defer os.RemoveAll(dir) - - st, err := Make(dir) + state, err := Make(dir) if err != nil { + os.RemoveAll(dir) t.Fatalf("failed to make state: %v", err) } - _, _, err = st.Sync("foo", nil, nil) - if err == nil { - t.Fatalf("synced with unconnected manager") - } - calls := []string{"read", "write"} - if err := st.Connect("foo", false, calls, nil); err != nil { - t.Fatalf("Connect failed: %v", err) - } - _, _, err = st.Sync("foo", nil, nil) + return &TestState{t, dir, state} +} + +func (ts *TestState) Close() { + os.RemoveAll(ts.dir) +} + +func (ts *TestState) Reload() { + ts.state.Flush() + state, err := Make(ts.dir) if err != nil { - t.Fatalf("Sync failed: %v", err) + ts.t.Fatalf("failed to make state: %v", err) } + ts.state = state } -func TestRepro(t *testing.T) { - dir, err := ioutil.TempDir("", "syz-hub-state-test") - if err != nil { - t.Fatalf("failed to create temp dir: %v", err) +func (ts *TestState) Connect(name, domain string, fresh bool, calls []string, corpus [][]byte) { + ts.t.Helper() + if err := ts.state.Connect(name, domain, fresh, calls, corpus); err != nil { + ts.t.Fatalf("Connect failed: %v", err) } - defer os.RemoveAll(dir) +} - st, err := Make(dir) +func (ts *TestState) Sync(name string, add [][]byte, del []string) (string, []rpctype.HubInput, int) { + ts.t.Helper() + domain, inputs, pending, err := ts.state.Sync(name, add, del) if err != nil { - t.Fatalf("failed to make state: %v", err) - } + ts.t.Fatalf("Sync failed: %v", err) + } + sort.Slice(inputs, func(i, j int) bool { + if inputs[i].Domain != inputs[j].Domain { + return inputs[i].Domain < inputs[j].Domain + } + return string(inputs[i].Prog) < string(inputs[j].Prog) + }) + return domain, inputs, pending +} - if err := st.Connect("foo", false, []string{"open", "read", "write"}, nil); err != nil { - t.Fatalf("Connect failed: %v", err) +func (ts *TestState) AddRepro(name string, repro []byte) { + ts.t.Helper() + if err := ts.state.AddRepro(name, repro); err != nil { + ts.t.Fatalf("AddRepro failed: %v", err) } - if err := st.Connect("bar", false, []string{"open", "read", "close"}, nil); err != nil { - t.Fatalf("Connect failed: %v", err) +} + +func (ts *TestState) PendingRepro(name string) []byte { + ts.t.Helper() + repro, err := ts.state.PendingRepro(name) + if err != nil { + ts.t.Fatalf("PendingRepro failed: %v", err) } - checkPendingRepro(t, st, "foo", "") - checkPendingRepro(t, st, "bar", "") + return repro +} + +func TestBasic(t *testing.T) { + st := MakeTestState(t) + defer st.Close() - if err := st.AddRepro("foo", []byte("open()")); err != nil { - t.Fatalf("AddRepro failed: %v", err) + if _, _, _, err := st.state.Sync("foo", nil, nil); err == nil { + t.Fatalf("synced with unconnected manager") } - checkPendingRepro(t, st, "foo", "") - checkPendingRepro(t, st, "bar", "open()") - checkPendingRepro(t, st, "bar", "") + calls := []string{"read", "write"} + st.Connect("foo", "", false, calls, nil) + st.Sync("foo", nil, nil) +} + +func TestRepro(t *testing.T) { + st := MakeTestState(t) + defer st.Close() + + st.Connect("foo", "", false, []string{"open", "read", "write"}, nil) + st.Connect("bar", "", false, []string{"open", "read", "close"}, nil) + + expectPendingRepro := func(name, result string) { + t.Helper() + repro := st.PendingRepro(name) + if string(repro) != result { + t.Fatalf("PendingRepro returned %q, want %q", string(repro), result) + } + } + expectPendingRepro("foo", "") + expectPendingRepro("bar", "") + st.AddRepro("foo", []byte("open()")) + expectPendingRepro("foo", "") + expectPendingRepro("bar", "open()") + expectPendingRepro("bar", "") // This repro is already present. - if err := st.AddRepro("bar", []byte("open()")); err != nil { - t.Fatalf("AddRepro failed: %v", err) - } - if err := st.AddRepro("bar", []byte("read()")); err != nil { - t.Fatalf("AddRepro failed: %v", err) - } - if err := st.AddRepro("bar", []byte("open()\nread()")); err != nil { - t.Fatalf("AddRepro failed: %v", err) - } + st.AddRepro("bar", []byte("open()")) + st.AddRepro("bar", []byte("read()")) + st.AddRepro("bar", []byte("open()\nread()")) // This does not satisfy foo's call set. - if err := st.AddRepro("bar", []byte("close()")); err != nil { - t.Fatalf("AddRepro failed: %v", err) - } - checkPendingRepro(t, st, "bar", "") + st.AddRepro("bar", []byte("close()")) + expectPendingRepro("bar", "") // Check how persistence works. - st, err = Make(dir) - if err != nil { - t.Fatalf("failed to make state: %v", err) - } - if err := st.Connect("foo", false, []string{"open", "read", "write"}, nil); err != nil { - t.Fatalf("Connect failed: %v", err) - } - if err := st.Connect("bar", false, []string{"open", "read", "close"}, nil); err != nil { - t.Fatalf("Connect failed: %v", err) - } - checkPendingRepro(t, st, "bar", "") - checkPendingRepro(t, st, "foo", "read()") - checkPendingRepro(t, st, "foo", "open()\nread()") - checkPendingRepro(t, st, "foo", "") + st.Reload() + st.Connect("foo", "", false, []string{"open", "read", "write"}, nil) + st.Connect("bar", "", false, []string{"open", "read", "close"}, nil) + expectPendingRepro("bar", "") + expectPendingRepro("foo", "read()") + expectPendingRepro("foo", "open()\nread()") + expectPendingRepro("foo", "") } -func checkPendingRepro(t *testing.T, st *State, name, result string) { - repro, err := st.PendingRepro(name) - if err != nil { - t.Fatalf("\n%v: PendingRepro failed: %v", caller(1), err) - } - if string(repro) != result { - t.Fatalf("\n%v: PendingRepro returned %q, want %q", caller(1), string(repro), result) - } -} +func TestDomain(t *testing.T) { + st := MakeTestState(t) + defer st.Close() -func caller(skip int) string { - _, file, line, _ := runtime.Caller(skip + 1) - return fmt.Sprintf("%v:%v", filepath.Base(file), line) + st.Connect("client0", "", false, []string{"open"}, nil) + st.Connect("client1", "domain1", false, []string{"open"}, nil) + st.Connect("client2", "domain2", false, []string{"open"}, nil) + st.Connect("client3", "domain3", false, []string{"open"}, nil) + { + domain, inputs, pending := st.Sync("client0", [][]byte{[]byte("open(0x0)")}, nil) + if domain != "" || len(inputs) != 0 || pending != 0 { + t.Fatalf("bad sync result: %v, %v, %v", domain, inputs, pending) + } + } + { + domain, inputs, pending := st.Sync("client0", [][]byte{[]byte("open(0x1)")}, nil) + if domain != "" || len(inputs) != 0 || pending != 0 { + t.Fatalf("bad sync result: %v, %v, %v", domain, inputs, pending) + } + } + { + domain, inputs, pending := st.Sync("client1", [][]byte{[]byte("open(0x2)"), []byte("open(0x1)")}, nil) + if domain != "domain1" || pending != 0 { + t.Fatalf("bad sync result: %v, %v, %v", domain, inputs, pending) + } + if diff := cmp.Diff(inputs, []rpctype.HubInput{ + {Domain: "", Prog: []byte("open(0x0)")}, + }); diff != "" { + t.Fatal(diff) + } + } + { + _, inputs, _ := st.Sync("client2", [][]byte{[]byte("open(0x3)")}, nil) + if diff := cmp.Diff(inputs, []rpctype.HubInput{ + {Domain: "", Prog: []byte("open(0x0)")}, + {Domain: "domain1", Prog: []byte("open(0x1)")}, + {Domain: "domain1", Prog: []byte("open(0x2)")}, + }); diff != "" { + t.Fatal(diff) + } + } + { + _, inputs, _ := st.Sync("client0", nil, nil) + if diff := cmp.Diff(inputs, []rpctype.HubInput{ + {Domain: "domain1", Prog: []byte("open(0x2)")}, + {Domain: "domain2", Prog: []byte("open(0x3)")}, + }); diff != "" { + t.Fatal(diff) + } + } + st.Reload() + st.Connect("client3", "domain3", false, []string{"open"}, nil) + { + _, inputs, _ := st.Sync("client3", nil, nil) + if diff := cmp.Diff(inputs, []rpctype.HubInput{ + {Domain: "", Prog: []byte("open(0x0)")}, + {Domain: "domain1", Prog: []byte("open(0x1)")}, + {Domain: "domain1", Prog: []byte("open(0x2)")}, + {Domain: "domain2", Prog: []byte("open(0x3)")}, + }); diff != "" { + t.Fatal(diff) + } + } } |
