aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2025-07-29 11:12:34 +0200
committerAleksandr Nogikh <nogikh@google.com>2025-07-29 09:54:26 +0000
commitea52e6c494b8bb65114b8f63180336dde1f98ef0 (patch)
treed4dc347bf70098eed34f37db18dd190ce123841c
parentd052a4c8ea018bce39f2ffeed6ce214d78c4381b (diff)
syz-cluster: refactor generic Spanner helpers
Extract the common "Query - ReadOne - close iterator" pattern into a separate method.
-rw-r--r--syz-cluster/pkg/db/build_repo.go4
-rw-r--r--syz-cluster/pkg/db/finding_repo.go18
-rw-r--r--syz-cluster/pkg/db/report_reply_repo.go28
-rw-r--r--syz-cluster/pkg/db/report_repo.go1
-rw-r--r--syz-cluster/pkg/db/series_repo.go39
-rw-r--r--syz-cluster/pkg/db/session_repo.go16
-rw-r--r--syz-cluster/pkg/db/session_test_repo.go37
-rw-r--r--syz-cluster/pkg/db/spanner.go37
8 files changed, 65 insertions, 115 deletions
diff --git a/syz-cluster/pkg/db/build_repo.go b/syz-cluster/pkg/db/build_repo.go
index 318144acc..6ce088a96 100644
--- a/syz-cluster/pkg/db/build_repo.go
+++ b/syz-cluster/pkg/db/build_repo.go
@@ -67,7 +67,5 @@ func (repo *BuildRepository) LastBuiltTree(ctx context.Context, params *LastBuil
stmt.Params["commit"] = params.Commit
}
stmt.SQL += " ORDER BY `CommitDate` DESC LIMIT 1"
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[Build](iter)
+ return readEntity[Build](ctx, repo.client.Single(), stmt)
}
diff --git a/syz-cluster/pkg/db/finding_repo.go b/syz-cluster/pkg/db/finding_repo.go
index 3472a72bf..0c9e1e0b9 100644
--- a/syz-cluster/pkg/db/finding_repo.go
+++ b/syz-cluster/pkg/db/finding_repo.go
@@ -41,7 +41,7 @@ func (repo *FindingRepository) Store(ctx context.Context, id *FindingID,
_, err := repo.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
// Query the existing finding, if it exists.
- stmt := spanner.Statement{
+ oldFinding, err := readEntity[Finding](ctx, txn, spanner.Statement{
SQL: "SELECT * from `Findings` WHERE `SessionID`=@sessionID " +
"AND `TestName` = @testName AND `Title`=@title",
Params: map[string]interface{}{
@@ -49,21 +49,15 @@ func (repo *FindingRepository) Store(ctx context.Context, id *FindingID,
"testName": id.TestName,
"title": id.Title,
},
- }
- iter := txn.Query(ctx, stmt)
- oldFinding, err := readOne[Finding](iter)
- iter.Stop()
+ })
if err != nil {
return err
}
// Query the Session object.
- stmt = spanner.Statement{
+ session, err := readEntity[Session](ctx, txn, spanner.Statement{
SQL: "SELECT * FROM `Sessions` WHERE `ID`=@id",
Params: map[string]interface{}{"id": id.SessionID},
- }
- iter = txn.Query(ctx, stmt)
- session, err := readOne[Session](iter)
- iter.Stop()
+ })
if err != nil {
return err
}
@@ -114,7 +108,5 @@ func (repo *FindingRepository) ListForSession(ctx context.Context, sessionID str
Params: map[string]interface{}{"session": sessionID},
}
addLimit(&stmt, limit)
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[Finding](iter)
+ return repo.readEntities(ctx, stmt)
}
diff --git a/syz-cluster/pkg/db/report_reply_repo.go b/syz-cluster/pkg/db/report_reply_repo.go
index 7e21403c2..e20bd6e7a 100644
--- a/syz-cluster/pkg/db/report_reply_repo.go
+++ b/syz-cluster/pkg/db/report_reply_repo.go
@@ -21,7 +21,10 @@ func NewReportReplyRepository(client *spanner.Client) *ReportReplyRepository {
}
func (repo *ReportReplyRepository) FindParentReportID(ctx context.Context, reporter, messageID string) (string, error) {
- stmt := spanner.Statement{
+ type result struct {
+ ReportID string `spanner:"ReportID"`
+ }
+ ret, err := readEntity[result](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT `ReportReplies`.ReportID FROM `ReportReplies` " +
"JOIN `SessionReports` ON `SessionReports`.ID = `ReportReplies`.ReportID " +
"WHERE `ReportReplies`.MessageID = @messageID " +
@@ -30,14 +33,7 @@ func (repo *ReportReplyRepository) FindParentReportID(ctx context.Context, repor
"reporter": reporter,
"messageID": messageID,
},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
-
- type result struct {
- ReportID string `spanner:"ReportID"`
- }
- ret, err := readOne[result](iter)
+ })
if err != nil {
return "", err
} else if ret != nil {
@@ -51,17 +47,14 @@ var ErrReportReplyExists = errors.New("the reply has already been recorded")
func (repo *ReportReplyRepository) Insert(ctx context.Context, reply *ReportReply) error {
_, err := repo.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
- stmt := spanner.Statement{
+ entity, err := readEntity[ReportReply](ctx, txn, spanner.Statement{
SQL: "SELECT * from `ReportReplies` " +
"WHERE `ReportID`=@reportID AND `MessageID`=@messageID",
Params: map[string]interface{}{
"reportID": reply.ReportID,
"messageID": reply.MessageID,
},
- }
- iter := txn.Query(ctx, stmt)
- entity, err := readOne[ReportReply](iter)
- iter.Stop()
+ })
if err != nil {
return err
} else if entity != nil {
@@ -77,7 +70,7 @@ func (repo *ReportReplyRepository) Insert(ctx context.Context, reply *ReportRepl
}
func (repo *ReportReplyRepository) LastForReporter(ctx context.Context, reporter string) (*ReportReply, error) {
- stmt := spanner.Statement{
+ return readEntity[ReportReply](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT `ReportReplies`.* FROM `ReportReplies` " +
"JOIN `SessionReports` ON `SessionReports`.ID=`ReportReplies`.ReportID " +
"WHERE `SessionReports`.Reporter=@reporter " +
@@ -85,8 +78,5 @@ func (repo *ReportReplyRepository) LastForReporter(ctx context.Context, reporter
Params: map[string]interface{}{
"reporter": reporter,
},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[ReportReply](iter)
+ })
}
diff --git a/syz-cluster/pkg/db/report_repo.go b/syz-cluster/pkg/db/report_repo.go
index c99e4a69d..6e9f469f8 100644
--- a/syz-cluster/pkg/db/report_repo.go
+++ b/syz-cluster/pkg/db/report_repo.go
@@ -50,6 +50,7 @@ func (repo *ReportRepository) Insert(ctx context.Context, rep *SessionReport) er
return fmt.Errorf("failed to pick a non-existing report ID")
}
+// nolint: dupl
func (repo *ReportRepository) ListNotReported(ctx context.Context, reporter string,
limit int) ([]*SessionReport, error) {
stmt := spanner.Statement{
diff --git a/syz-cluster/pkg/db/series_repo.go b/syz-cluster/pkg/db/series_repo.go
index d79bb1d4e..0c6ce4f79 100644
--- a/syz-cluster/pkg/db/series_repo.go
+++ b/syz-cluster/pkg/db/series_repo.go
@@ -36,24 +36,18 @@ func NewSeriesRepository(client *spanner.Client) *SeriesRepository {
// TODO: move to SeriesPatchesRepository?
// nolint:dupl
func (repo *SeriesRepository) PatchByID(ctx context.Context, id string) (*Patch, error) {
- stmt := spanner.Statement{
+ return readEntity[Patch](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM Patches WHERE ID=@id",
Params: map[string]interface{}{"id": id},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[Patch](iter)
+ })
}
// nolint:dupl
func (repo *SeriesRepository) GetByExtID(ctx context.Context, extID string) (*Series, error) {
- stmt := spanner.Statement{
+ return readEntity[Series](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM Series WHERE ExtID=@extID",
Params: map[string]interface{}{"extID": extID},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[Series](iter)
+ })
}
var ErrSeriesExists = errors.New("the series already exists")
@@ -191,10 +185,7 @@ func (repo *SeriesRepository) ListLatest(ctx context.Context, filter SeriesFilte
stmt.SQL += " OFFSET @offset"
stmt.Params["offset"] = filter.Offset
}
- iter := ro.Query(ctx, stmt)
- defer iter.Stop()
-
- seriesList, err := readEntities[Series](iter)
+ seriesList, err := readEntities[Series](ctx, ro, stmt)
if err != nil {
return nil, err
}
@@ -232,14 +223,12 @@ func (repo *SeriesRepository) querySessions(ctx context.Context, ro *spanner.Rea
if len(keys) == 0 {
return nil
}
- iter := ro.Query(ctx, spanner.Statement{
+ sessions, err := readEntities[Session](ctx, ro, spanner.Statement{
SQL: "SELECT * FROM Sessions WHERE ID IN UNNEST(@ids)",
Params: map[string]interface{}{
"ids": keys,
},
})
- defer iter.Stop()
- sessions, err := readEntities[Session](iter)
if err != nil {
return err
}
@@ -271,18 +260,13 @@ func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanne
SessionID string `spanner:"SessionID"`
Count int64 `spanner:"Count"`
}
-
- stmt := spanner.Statement{
+ list, err := readEntities[findingCount](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT `SessionID`, COUNT(`ID`) as `Count` FROM `Findings` " +
"WHERE `SessionID` IN UNNEST(@ids) GROUP BY `SessionID`",
Params: map[string]interface{}{
"ids": keys,
},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
-
- list, err := readEntities[findingCount](iter)
+ })
if err != nil {
return err
}
@@ -295,13 +279,10 @@ func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanne
// golint sees too much similarity with SessionRepository's ListForSeries, but in reality there's not.
// nolint:dupl
func (repo *SeriesRepository) ListPatches(ctx context.Context, series *Series) ([]*Patch, error) {
- stmt := spanner.Statement{
+ return readEntities[Patch](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM `Patches` WHERE `SeriesID` = @seriesID ORDER BY `Seq`",
Params: map[string]interface{}{
"seriesID": series.ID,
},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[Patch](iter)
+ })
}
diff --git a/syz-cluster/pkg/db/session_repo.go b/syz-cluster/pkg/db/session_repo.go
index f4fc3233c..3db95092b 100644
--- a/syz-cluster/pkg/db/session_repo.go
+++ b/syz-cluster/pkg/db/session_repo.go
@@ -33,12 +33,10 @@ var ErrSessionAlreadyStarted = errors.New("the session already started")
func (repo *SessionRepository) Start(ctx context.Context, sessionID string) error {
_, err := repo.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
- iter := txn.Query(ctx, spanner.Statement{
+ session, err := readEntity[Session](ctx, txn, spanner.Statement{
SQL: "SELECT * from `Sessions` WHERE `ID`=@id",
Params: map[string]interface{}{"id": sessionID},
})
- session, err := readOne[Session](iter)
- iter.Stop()
if err != nil {
return err
}
@@ -50,12 +48,10 @@ func (repo *SessionRepository) Start(ctx context.Context, sessionID string) erro
if err != nil {
return err
}
- iter = txn.Query(ctx, spanner.Statement{
+ series, err := readEntity[Series](ctx, txn, spanner.Statement{
SQL: "SELECT * from `Series` WHERE `ID`=@id",
Params: map[string]interface{}{"id": session.SeriesID},
})
- series, err := readOne[Series](iter)
- iter.Stop()
if err != nil {
return err
}
@@ -77,11 +73,9 @@ func (repo *SessionRepository) Insert(ctx context.Context, session *Session) err
}
func (repo *SessionRepository) ListRunning(ctx context.Context) ([]*Session, error) {
- stmt := spanner.Statement{SQL: "SELECT * FROM `Sessions` WHERE `StartedAt` IS NOT NULL " +
- "AND `FinishedAt` IS NULL"}
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[Session](iter)
+ return repo.readEntities(ctx, spanner.Statement{
+ SQL: "SELECT * FROM `Sessions` WHERE `StartedAt` IS NOT NULL AND `FinishedAt` IS NULL",
+ })
}
type NextSession struct {
diff --git a/syz-cluster/pkg/db/session_test_repo.go b/syz-cluster/pkg/db/session_test_repo.go
index 7043b8389..03316c38c 100644
--- a/syz-cluster/pkg/db/session_test_repo.go
+++ b/syz-cluster/pkg/db/session_test_repo.go
@@ -7,7 +7,6 @@ import (
"context"
"cloud.google.com/go/spanner"
- "google.golang.org/api/iterator"
)
type SessionTestRepository struct {
@@ -26,20 +25,17 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses
_, err := repo.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
// Check if the test already exists.
- stmt := spanner.Statement{
+ dbTest, err := readEntity[SessionTest](ctx, txn, spanner.Statement{
SQL: "SELECT * from `SessionTests` WHERE `SessionID`=@sessionID AND `TestName` = @testName",
Params: map[string]interface{}{
"sessionID": test.SessionID,
"testName": test.TestName,
},
- }
- iter := txn.Query(ctx, stmt)
- defer iter.Stop()
-
+ })
var stmts []*spanner.Mutation
-
- _, iterErr := iter.Next()
- if iterErr == nil {
+ if err != nil {
+ return err
+ } else if dbTest != nil {
if beforeSave != nil {
beforeSave(test)
}
@@ -48,8 +44,6 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses
return err
}
stmts = append(stmts, m)
- } else if iterErr != iterator.Done {
- return iterErr
} else {
if beforeSave != nil {
beforeSave(test)
@@ -66,16 +60,13 @@ func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *Ses
}
func (repo *SessionTestRepository) Get(ctx context.Context, sessionID, testName string) (*SessionTest, error) {
- stmt := spanner.Statement{
+ return readEntity[SessionTest](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session AND `TestName` = @name",
Params: map[string]interface{}{
"session": sessionID,
"name": testName,
},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[SessionTest](iter)
+ })
}
type FullSessionTest struct {
@@ -106,13 +97,10 @@ func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID stri
for key := range needBuilds {
keys = append(keys, key)
}
- stmt := spanner.Statement{
+ builds, err := readEntities[Build](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM `Builds` WHERE `ID` IN UNNEST(@ids)",
Params: map[string]interface{}{"ids": keys},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- builds, err := readEntities[Build](iter)
+ })
if err != nil {
return nil, err
}
@@ -126,12 +114,9 @@ func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID stri
}
func (repo *SessionTestRepository) BySessionRaw(ctx context.Context, sessionID string) ([]*SessionTest, error) {
- stmt := spanner.Statement{
+ return readEntities[SessionTest](ctx, repo.client.Single(), spanner.Statement{
SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session" +
" ORDER BY `UpdatedAt`",
Params: map[string]interface{}{"session": sessionID},
- }
- iter := repo.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[SessionTest](iter)
+ })
}
diff --git a/syz-cluster/pkg/db/spanner.go b/syz-cluster/pkg/db/spanner.go
index 8defc9541..39c483975 100644
--- a/syz-cluster/pkg/db/spanner.go
+++ b/syz-cluster/pkg/db/spanner.go
@@ -240,7 +240,7 @@ func runSpanner(bin string) (*exec.Cmd, string, error) {
return cmd, host, nil
}
-func readOne[T any](iter *spanner.RowIterator) (*T, error) {
+func readRow[T any](iter *spanner.RowIterator) (*T, error) {
row, err := iter.Next()
if err == iterator.Done {
return nil, nil
@@ -256,10 +256,20 @@ func readOne[T any](iter *spanner.RowIterator) (*T, error) {
return &obj, nil
}
-func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) {
+type dbQuerier interface {
+ Query(context.Context, spanner.Statement) *spanner.RowIterator
+}
+
+func readEntity[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) (*T, error) {
+ iter := txn.Query(ctx, stmt)
+ defer iter.Stop()
+ return readRow[T](iter)
+}
+
+func readRows[T any](iter *spanner.RowIterator) ([]*T, error) {
var ret []*T
for {
- obj, err := readOne[T](iter)
+ obj, err := readRow[T](iter)
if err != nil {
return nil, err
}
@@ -271,6 +281,12 @@ func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) {
return ret, nil
}
+func readEntities[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) ([]*T, error) {
+ iter := txn.Query(ctx, stmt)
+ defer iter.Stop()
+ return readRows[T](iter)
+}
+
const NoLimit = 0
func addLimit(stmt *spanner.Statement, limit int) {
@@ -291,9 +307,7 @@ func (g *genericEntityOps[EntityType, KeyType]) GetByID(ctx context.Context, key
SQL: "SELECT * FROM " + g.table + " WHERE " + g.keyField + "=@key",
Params: map[string]interface{}{"key": key},
}
- iter := g.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[EntityType](iter)
+ return readEntity[EntityType](ctx, g.client.Single(), stmt)
}
var ErrEntityNotFound = errors.New("entity not found")
@@ -302,13 +316,10 @@ func (g *genericEntityOps[EntityType, KeyType]) Update(ctx context.Context, key
cb func(*EntityType) error) error {
_, err := g.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
- stmt := spanner.Statement{
+ entity, err := readEntity[EntityType](ctx, txn, spanner.Statement{
SQL: "SELECT * from `" + g.table + "` WHERE `" + g.keyField + "`=@key",
Params: map[string]interface{}{"key": key},
- }
- iter := txn.Query(ctx, stmt)
- entity, err := readOne[EntityType](iter)
- iter.Stop()
+ })
if err != nil {
return err
}
@@ -347,7 +358,5 @@ func (g *genericEntityOps[EntityType, KeyType]) Insert(ctx context.Context, obj
func (g *genericEntityOps[EntityType, KeyType]) readEntities(ctx context.Context, stmt spanner.Statement) (
[]*EntityType, error) {
- iter := g.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[EntityType](iter)
+ return readEntities[EntityType](ctx, g.client.Single(), stmt)
}