From ea52e6c494b8bb65114b8f63180336dde1f98ef0 Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Tue, 29 Jul 2025 11:12:34 +0200 Subject: syz-cluster: refactor generic Spanner helpers Extract the common "Query - ReadOne - close iterator" pattern into a separate method. --- syz-cluster/pkg/db/build_repo.go | 4 +--- syz-cluster/pkg/db/finding_repo.go | 18 +++++---------- syz-cluster/pkg/db/report_reply_repo.go | 28 ++++++++--------------- syz-cluster/pkg/db/report_repo.go | 1 + syz-cluster/pkg/db/series_repo.go | 39 +++++++++------------------------ syz-cluster/pkg/db/session_repo.go | 16 +++++--------- syz-cluster/pkg/db/session_test_repo.go | 37 ++++++++++--------------------- syz-cluster/pkg/db/spanner.go | 37 +++++++++++++++++++------------ 8 files changed, 65 insertions(+), 115 deletions(-) diff --git a/syz-cluster/pkg/db/build_repo.go b/syz-cluster/pkg/db/build_repo.go index 318144acc..6ce088a96 100644 --- a/syz-cluster/pkg/db/build_repo.go +++ b/syz-cluster/pkg/db/build_repo.go @@ -67,7 +67,5 @@ func (repo *BuildRepository) LastBuiltTree(ctx context.Context, params *LastBuil stmt.Params["commit"] = params.Commit } stmt.SQL += " ORDER BY `CommitDate` DESC LIMIT 1" - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[Build](iter) + return readEntity[Build](ctx, repo.client.Single(), stmt) } diff --git a/syz-cluster/pkg/db/finding_repo.go b/syz-cluster/pkg/db/finding_repo.go index 3472a72bf..0c9e1e0b9 100644 --- a/syz-cluster/pkg/db/finding_repo.go +++ b/syz-cluster/pkg/db/finding_repo.go @@ -41,7 +41,7 @@ func (repo *FindingRepository) Store(ctx context.Context, id *FindingID, _, err := repo.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { // Query the existing finding, if it exists. - stmt := spanner.Statement{ + oldFinding, err := readEntity[Finding](ctx, txn, spanner.Statement{ SQL: "SELECT * from `Findings` WHERE `SessionID`=@sessionID " + "AND `TestName` = @testName AND `Title`=@title", Params: map[string]interface{}{ @@ -49,21 +49,15 @@ func (repo *FindingRepository) Store(ctx context.Context, id *FindingID, "testName": id.TestName, "title": id.Title, }, - } - iter := txn.Query(ctx, stmt) - oldFinding, err := readOne[Finding](iter) - iter.Stop() + }) if err != nil { return err } // Query the Session object. - stmt = spanner.Statement{ + session, err := readEntity[Session](ctx, txn, spanner.Statement{ SQL: "SELECT * FROM `Sessions` WHERE `ID`=@id", Params: map[string]interface{}{"id": id.SessionID}, - } - iter = txn.Query(ctx, stmt) - session, err := readOne[Session](iter) - iter.Stop() + }) if err != nil { return err } @@ -114,7 +108,5 @@ func (repo *FindingRepository) ListForSession(ctx context.Context, sessionID str Params: map[string]interface{}{"session": sessionID}, } addLimit(&stmt, limit) - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readEntities[Finding](iter) + return repo.readEntities(ctx, stmt) } diff --git a/syz-cluster/pkg/db/report_reply_repo.go b/syz-cluster/pkg/db/report_reply_repo.go index 7e21403c2..e20bd6e7a 100644 --- a/syz-cluster/pkg/db/report_reply_repo.go +++ b/syz-cluster/pkg/db/report_reply_repo.go @@ -21,7 +21,10 @@ func NewReportReplyRepository(client *spanner.Client) *ReportReplyRepository { } func (repo *ReportReplyRepository) FindParentReportID(ctx context.Context, reporter, messageID string) (string, error) { - stmt := spanner.Statement{ + type result struct { + ReportID string `spanner:"ReportID"` + } + ret, err := readEntity[result](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT `ReportReplies`.ReportID FROM `ReportReplies` " + "JOIN `SessionReports` ON `SessionReports`.ID = `ReportReplies`.ReportID " + "WHERE `ReportReplies`.MessageID = @messageID " + @@ -30,14 +33,7 @@ func (repo *ReportReplyRepository) FindParentReportID(ctx context.Context, repor "reporter": reporter, "messageID": messageID, }, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - - type result struct { - ReportID string `spanner:"ReportID"` - } - ret, err := readOne[result](iter) + }) if err != nil { return "", err } else if ret != nil { @@ -51,17 +47,14 @@ var ErrReportReplyExists = errors.New("the reply has already been recorded") func (repo *ReportReplyRepository) Insert(ctx context.Context, reply *ReportReply) error { _, err := repo.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - stmt := spanner.Statement{ + entity, err := readEntity[ReportReply](ctx, txn, spanner.Statement{ SQL: "SELECT * from `ReportReplies` " + "WHERE `ReportID`=@reportID AND `MessageID`=@messageID", Params: map[string]interface{}{ "reportID": reply.ReportID, "messageID": reply.MessageID, }, - } - iter := txn.Query(ctx, stmt) - entity, err := readOne[ReportReply](iter) - iter.Stop() + }) if err != nil { return err } else if entity != nil { @@ -77,7 +70,7 @@ func (repo *ReportReplyRepository) Insert(ctx context.Context, reply *ReportRepl } func (repo *ReportReplyRepository) LastForReporter(ctx context.Context, reporter string) (*ReportReply, error) { - stmt := spanner.Statement{ + return readEntity[ReportReply](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT `ReportReplies`.* FROM `ReportReplies` " + "JOIN `SessionReports` ON `SessionReports`.ID=`ReportReplies`.ReportID " + "WHERE `SessionReports`.Reporter=@reporter " + @@ -85,8 +78,5 @@ func (repo *ReportReplyRepository) LastForReporter(ctx context.Context, reporter Params: map[string]interface{}{ "reporter": reporter, }, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[ReportReply](iter) + }) } diff --git a/syz-cluster/pkg/db/report_repo.go b/syz-cluster/pkg/db/report_repo.go index c99e4a69d..6e9f469f8 100644 --- a/syz-cluster/pkg/db/report_repo.go +++ b/syz-cluster/pkg/db/report_repo.go @@ -50,6 +50,7 @@ func (repo *ReportRepository) Insert(ctx context.Context, rep *SessionReport) er return fmt.Errorf("failed to pick a non-existing report ID") } +// nolint: dupl func (repo *ReportRepository) ListNotReported(ctx context.Context, reporter string, limit int) ([]*SessionReport, error) { stmt := spanner.Statement{ diff --git a/syz-cluster/pkg/db/series_repo.go b/syz-cluster/pkg/db/series_repo.go index d79bb1d4e..0c6ce4f79 100644 --- a/syz-cluster/pkg/db/series_repo.go +++ b/syz-cluster/pkg/db/series_repo.go @@ -36,24 +36,18 @@ func NewSeriesRepository(client *spanner.Client) *SeriesRepository { // TODO: move to SeriesPatchesRepository? // nolint:dupl func (repo *SeriesRepository) PatchByID(ctx context.Context, id string) (*Patch, error) { - stmt := spanner.Statement{ + return readEntity[Patch](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM Patches WHERE ID=@id", Params: map[string]interface{}{"id": id}, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[Patch](iter) + }) } // nolint:dupl func (repo *SeriesRepository) GetByExtID(ctx context.Context, extID string) (*Series, error) { - stmt := spanner.Statement{ + return readEntity[Series](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM Series WHERE ExtID=@extID", Params: map[string]interface{}{"extID": extID}, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[Series](iter) + }) } var ErrSeriesExists = errors.New("the series already exists") @@ -191,10 +185,7 @@ func (repo *SeriesRepository) ListLatest(ctx context.Context, filter SeriesFilte stmt.SQL += " OFFSET @offset" stmt.Params["offset"] = filter.Offset } - iter := ro.Query(ctx, stmt) - defer iter.Stop() - - seriesList, err := readEntities[Series](iter) + seriesList, err := readEntities[Series](ctx, ro, stmt) if err != nil { return nil, err } @@ -232,14 +223,12 @@ func (repo *SeriesRepository) querySessions(ctx context.Context, ro *spanner.Rea if len(keys) == 0 { return nil } - iter := ro.Query(ctx, spanner.Statement{ + sessions, err := readEntities[Session](ctx, ro, spanner.Statement{ SQL: "SELECT * FROM Sessions WHERE ID IN UNNEST(@ids)", Params: map[string]interface{}{ "ids": keys, }, }) - defer iter.Stop() - sessions, err := readEntities[Session](iter) if err != nil { return err } @@ -271,18 +260,13 @@ func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanne SessionID string `spanner:"SessionID"` Count int64 `spanner:"Count"` } - - stmt := spanner.Statement{ + list, err := readEntities[findingCount](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT `SessionID`, COUNT(`ID`) as `Count` FROM `Findings` " + "WHERE `SessionID` IN UNNEST(@ids) GROUP BY `SessionID`", Params: map[string]interface{}{ "ids": keys, }, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - - list, err := readEntities[findingCount](iter) + }) if err != nil { return err } @@ -295,13 +279,10 @@ func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanne // golint sees too much similarity with SessionRepository's ListForSeries, but in reality there's not. // nolint:dupl func (repo *SeriesRepository) ListPatches(ctx context.Context, series *Series) ([]*Patch, error) { - stmt := spanner.Statement{ + return readEntities[Patch](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM `Patches` WHERE `SeriesID` = @seriesID ORDER BY `Seq`", Params: map[string]interface{}{ "seriesID": series.ID, }, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readEntities[Patch](iter) + }) } diff --git a/syz-cluster/pkg/db/session_repo.go b/syz-cluster/pkg/db/session_repo.go index f4fc3233c..3db95092b 100644 --- a/syz-cluster/pkg/db/session_repo.go +++ b/syz-cluster/pkg/db/session_repo.go @@ -33,12 +33,10 @@ var ErrSessionAlreadyStarted = errors.New("the session already started") func (repo *SessionRepository) Start(ctx context.Context, sessionID string) error { _, err := repo.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - iter := txn.Query(ctx, spanner.Statement{ + session, err := readEntity[Session](ctx, txn, spanner.Statement{ SQL: "SELECT * from `Sessions` WHERE `ID`=@id", Params: map[string]interface{}{"id": sessionID}, }) - session, err := readOne[Session](iter) - iter.Stop() if err != nil { return err } @@ -50,12 +48,10 @@ func (repo *SessionRepository) Start(ctx context.Context, sessionID string) erro if err != nil { return err } - iter = txn.Query(ctx, spanner.Statement{ + series, err := readEntity[Series](ctx, txn, spanner.Statement{ SQL: "SELECT * from `Series` WHERE `ID`=@id", Params: map[string]interface{}{"id": session.SeriesID}, }) - series, err := readOne[Series](iter) - iter.Stop() if err != nil { return err } @@ -77,11 +73,9 @@ func (repo *SessionRepository) Insert(ctx context.Context, session *Session) err } func (repo *SessionRepository) ListRunning(ctx context.Context) ([]*Session, error) { - stmt := spanner.Statement{SQL: "SELECT * FROM `Sessions` WHERE `StartedAt` IS NOT NULL " + - "AND `FinishedAt` IS NULL"} - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readEntities[Session](iter) + return repo.readEntities(ctx, spanner.Statement{ + SQL: "SELECT * FROM `Sessions` WHERE `StartedAt` IS NOT NULL AND `FinishedAt` IS NULL", + }) } type NextSession struct { diff --git a/syz-cluster/pkg/db/session_test_repo.go b/syz-cluster/pkg/db/session_test_repo.go index 7043b8389..03316c38c 100644 --- a/syz-cluster/pkg/db/session_test_repo.go +++ b/syz-cluster/pkg/db/session_test_repo.go @@ -7,7 +7,6 @@ import ( "context" "cloud.google.com/go/spanner" - "google.golang.org/api/iterator" ) type SessionTestRepository struct { @@ -26,20 +25,17 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses _, err := repo.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { // Check if the test already exists. - stmt := spanner.Statement{ + dbTest, err := readEntity[SessionTest](ctx, txn, spanner.Statement{ SQL: "SELECT * from `SessionTests` WHERE `SessionID`=@sessionID AND `TestName` = @testName", Params: map[string]interface{}{ "sessionID": test.SessionID, "testName": test.TestName, }, - } - iter := txn.Query(ctx, stmt) - defer iter.Stop() - + }) var stmts []*spanner.Mutation - - _, iterErr := iter.Next() - if iterErr == nil { + if err != nil { + return err + } else if dbTest != nil { if beforeSave != nil { beforeSave(test) } @@ -48,8 +44,6 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses return err } stmts = append(stmts, m) - } else if iterErr != iterator.Done { - return iterErr } else { if beforeSave != nil { beforeSave(test) @@ -66,16 +60,13 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses } func (repo *SessionTestRepository) Get(ctx context.Context, sessionID, testName string) (*SessionTest, error) { - stmt := spanner.Statement{ + return readEntity[SessionTest](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session AND `TestName` = @name", Params: map[string]interface{}{ "session": sessionID, "name": testName, }, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[SessionTest](iter) + }) } type FullSessionTest struct { @@ -106,13 +97,10 @@ func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID stri for key := range needBuilds { keys = append(keys, key) } - stmt := spanner.Statement{ + builds, err := readEntities[Build](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM `Builds` WHERE `ID` IN UNNEST(@ids)", Params: map[string]interface{}{"ids": keys}, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - builds, err := readEntities[Build](iter) + }) if err != nil { return nil, err } @@ -126,12 +114,9 @@ func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID stri } func (repo *SessionTestRepository) BySessionRaw(ctx context.Context, sessionID string) ([]*SessionTest, error) { - stmt := spanner.Statement{ + return readEntities[SessionTest](ctx, repo.client.Single(), spanner.Statement{ SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session" + " ORDER BY `UpdatedAt`", Params: map[string]interface{}{"session": sessionID}, - } - iter := repo.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readEntities[SessionTest](iter) + }) } diff --git a/syz-cluster/pkg/db/spanner.go b/syz-cluster/pkg/db/spanner.go index 8defc9541..39c483975 100644 --- a/syz-cluster/pkg/db/spanner.go +++ b/syz-cluster/pkg/db/spanner.go @@ -240,7 +240,7 @@ func runSpanner(bin string) (*exec.Cmd, string, error) { return cmd, host, nil } -func readOne[T any](iter *spanner.RowIterator) (*T, error) { +func readRow[T any](iter *spanner.RowIterator) (*T, error) { row, err := iter.Next() if err == iterator.Done { return nil, nil @@ -256,10 +256,20 @@ func readOne[T any](iter *spanner.RowIterator) (*T, error) { return &obj, nil } -func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) { +type dbQuerier interface { + Query(context.Context, spanner.Statement) *spanner.RowIterator +} + +func readEntity[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) (*T, error) { + iter := txn.Query(ctx, stmt) + defer iter.Stop() + return readRow[T](iter) +} + +func readRows[T any](iter *spanner.RowIterator) ([]*T, error) { var ret []*T for { - obj, err := readOne[T](iter) + obj, err := readRow[T](iter) if err != nil { return nil, err } @@ -271,6 +281,12 @@ func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) { return ret, nil } +func readEntities[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) ([]*T, error) { + iter := txn.Query(ctx, stmt) + defer iter.Stop() + return readRows[T](iter) +} + const NoLimit = 0 func addLimit(stmt *spanner.Statement, limit int) { @@ -291,9 +307,7 @@ func (g *genericEntityOps[EntityType, KeyType]) GetByID(ctx context.Context, key SQL: "SELECT * FROM " + g.table + " WHERE " + g.keyField + "=@key", Params: map[string]interface{}{"key": key}, } - iter := g.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readOne[EntityType](iter) + return readEntity[EntityType](ctx, g.client.Single(), stmt) } var ErrEntityNotFound = errors.New("entity not found") @@ -302,13 +316,10 @@ func (g *genericEntityOps[EntityType, KeyType]) Update(ctx context.Context, key cb func(*EntityType) error) error { _, err := g.client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - stmt := spanner.Statement{ + entity, err := readEntity[EntityType](ctx, txn, spanner.Statement{ SQL: "SELECT * from `" + g.table + "` WHERE `" + g.keyField + "`=@key", Params: map[string]interface{}{"key": key}, - } - iter := txn.Query(ctx, stmt) - entity, err := readOne[EntityType](iter) - iter.Stop() + }) if err != nil { return err } @@ -347,7 +358,5 @@ func (g *genericEntityOps[EntityType, KeyType]) Insert(ctx context.Context, obj func (g *genericEntityOps[EntityType, KeyType]) readEntities(ctx context.Context, stmt spanner.Statement) ( []*EntityType, error) { - iter := g.client.Single().Query(ctx, stmt) - defer iter.Stop() - return readEntities[EntityType](iter) + return readEntities[EntityType](ctx, g.client.Single(), stmt) } -- cgit mrf-deployment