aboutsummaryrefslogtreecommitdiffstats
path: root/syz-cluster/pkg/db/spanner.go
diff options
context:
space:
mode:
Diffstat (limited to 'syz-cluster/pkg/db/spanner.go')
-rw-r--r--syz-cluster/pkg/db/spanner.go37
1 files changed, 23 insertions, 14 deletions
diff --git a/syz-cluster/pkg/db/spanner.go b/syz-cluster/pkg/db/spanner.go
index 8defc9541..39c483975 100644
--- a/syz-cluster/pkg/db/spanner.go
+++ b/syz-cluster/pkg/db/spanner.go
@@ -240,7 +240,7 @@ func runSpanner(bin string) (*exec.Cmd, string, error) {
return cmd, host, nil
}
-func readOne[T any](iter *spanner.RowIterator) (*T, error) {
+func readRow[T any](iter *spanner.RowIterator) (*T, error) {
row, err := iter.Next()
if err == iterator.Done {
return nil, nil
@@ -256,10 +256,20 @@ func readOne[T any](iter *spanner.RowIterator) (*T, error) {
return &obj, nil
}
-func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) {
+type dbQuerier interface {
+ Query(context.Context, spanner.Statement) *spanner.RowIterator
+}
+
+func readEntity[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) (*T, error) {
+ iter := txn.Query(ctx, stmt)
+ defer iter.Stop()
+ return readRow[T](iter)
+}
+
+func readRows[T any](iter *spanner.RowIterator) ([]*T, error) {
var ret []*T
for {
- obj, err := readOne[T](iter)
+ obj, err := readRow[T](iter)
if err != nil {
return nil, err
}
@@ -271,6 +281,12 @@ func readEntities[T any](iter *spanner.RowIterator) ([]*T, error) {
return ret, nil
}
+func readEntities[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) ([]*T, error) {
+ iter := txn.Query(ctx, stmt)
+ defer iter.Stop()
+ return readRows[T](iter)
+}
+
const NoLimit = 0
func addLimit(stmt *spanner.Statement, limit int) {
@@ -291,9 +307,7 @@ func (g *genericEntityOps[EntityType, KeyType]) GetByID(ctx context.Context, key
SQL: "SELECT * FROM " + g.table + " WHERE " + g.keyField + "=@key",
Params: map[string]interface{}{"key": key},
}
- iter := g.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readOne[EntityType](iter)
+ return readEntity[EntityType](ctx, g.client.Single(), stmt)
}
var ErrEntityNotFound = errors.New("entity not found")
@@ -302,13 +316,10 @@ func (g *genericEntityOps[EntityType, KeyType]) Update(ctx context.Context, key
cb func(*EntityType) error) error {
_, err := g.client.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
- stmt := spanner.Statement{
+ entity, err := readEntity[EntityType](ctx, txn, spanner.Statement{
SQL: "SELECT * from `" + g.table + "` WHERE `" + g.keyField + "`=@key",
Params: map[string]interface{}{"key": key},
- }
- iter := txn.Query(ctx, stmt)
- entity, err := readOne[EntityType](iter)
- iter.Stop()
+ })
if err != nil {
return err
}
@@ -347,7 +358,5 @@ func (g *genericEntityOps[EntityType, KeyType]) Insert(ctx context.Context, obj
func (g *genericEntityOps[EntityType, KeyType]) readEntities(ctx context.Context, stmt spanner.Statement) (
[]*EntityType, error) {
- iter := g.client.Single().Query(ctx, stmt)
- defer iter.Stop()
- return readEntities[EntityType](iter)
+ return readEntities[EntityType](ctx, g.client.Single(), stmt)
}