aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorTaras Madan <tarasmadan@google.com>2024-08-21 12:42:58 +0200
committerTaras Madan <tarasmadan@google.com>2024-08-22 09:08:20 +0000
commitaa99fc3349e97ea596c31624efde306de4136241 (patch)
tree8ec3d255838cd3990170bd33bfc0d641c0aac418 /pkg
parentca02180f7c9d6b3a7de8a887f3998725ae2f0c51 (diff)
pkg/validator: initial code
Diffstat (limited to 'pkg')
-rw-r--r--pkg/covermerger/bq_csv_reader.go16
-rw-r--r--pkg/validator/validator.go97
-rw-r--r--pkg/validator/validator_test.go85
3 files changed, 189 insertions, 9 deletions
diff --git a/pkg/covermerger/bq_csv_reader.go b/pkg/covermerger/bq_csv_reader.go
index 68711e318..60f6e5829 100644
--- a/pkg/covermerger/bq_csv_reader.go
+++ b/pkg/covermerger/bq_csv_reader.go
@@ -8,20 +8,14 @@ import (
"context"
"fmt"
"io"
- "regexp"
"cloud.google.com/go/bigquery"
"cloud.google.com/go/civil"
"github.com/google/syzkaller/pkg/gcs"
+ "github.com/google/syzkaller/pkg/validator"
"github.com/google/uuid"
)
-var allowedFilePath = regexp.MustCompile(`^[./_a-zA-Z0-9]*$`)
-
-func isAllowedFilePath(s string) bool {
- return allowedFilePath.MatchString(s)
-}
-
type bqCSVReader struct {
closers []io.Closer
gcsFiles []io.Reader
@@ -38,8 +32,12 @@ func MakeBQCSVReader() *bqCSVReader {
}
func (r *bqCSVReader) InitNsRecords(ctx context.Context, ns, filePath, commit string, from, to civil.Date) error {
- if !isAllowedFilePath(filePath) {
- return fmt.Errorf("wrong file path '%s'", filePath)
+ if err := validator.AnyError("input validation failed",
+ validator.NamespaceName(ns),
+ validator.KernelFilePath(filePath),
+ validator.AnyOk(validator.EmptyStr(commit), validator.CommitHash(commit)),
+ ); err != nil {
+ return err
}
sessionUUID := uuid.New().String()
gsBucket := "syzbot-temp"
diff --git a/pkg/validator/validator.go b/pkg/validator/validator.go
new file mode 100644
index 000000000..420830835
--- /dev/null
+++ b/pkg/validator/validator.go
@@ -0,0 +1,97 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package validator
+
+import (
+ "errors"
+ "fmt"
+ "regexp"
+
+ "github.com/google/syzkaller/pkg/auth"
+)
+
+type Result struct {
+ Ok bool
+ Err error
+}
+
+var ResultOk = Result{true, nil}
+
+func AnyError(errPrefix string, results ...Result) error {
+ for _, res := range results {
+ if !res.Ok {
+ return wrapError(res.Err.Error(), errPrefix)
+ }
+ }
+ return nil
+}
+
+func AnyOk(results ...Result) Result {
+ if len(results) == 0 {
+ return ResultOk
+ }
+ for _, res := range results {
+ if res.Ok {
+ return ResultOk
+ }
+ }
+ return results[0]
+}
+
+func PanicIfNot(results ...Result) error {
+ if err := AnyError("", results...); err != nil {
+ panic(err.Error())
+ }
+ return nil
+}
+
+var (
+ EmptyStr = makeStrLenFunc("not empty", 0)
+ AlphaNumeric = makeStrReFunc("not an alphanum", "^[a-zA-Z0-9]*$")
+ CommitHash = makeCombinedStrFunc("not a hash", AlphaNumeric, makeStrLenFunc("len is not 40", 40))
+ KernelFilePath = makeStrReFunc("not a kernel file path", "^[./_a-zA-Z0-9]*$")
+ NamespaceName = makeStrReFunc("not a namespace name", "^[a-zA-Z0-9-_.]{4,32}$")
+ DashClientName = makeStrReFunc("not a dashboard client name", "^[a-zA-Z0-9-_.]{4,100}$")
+ DashClientKey = makeStrReFunc("not a dashboard client key",
+ "^([a-zA-Z0-9]{16,128})|("+regexp.QuoteMeta(auth.OauthMagic)+".*)$")
+)
+
+type strValidationFunc func(string, ...string) Result
+
+func makeStrReFunc(errStr, reStr string) strValidationFunc {
+ matchRe := regexp.MustCompile(reStr)
+ return func(s string, objName ...string) Result {
+ if !matchRe.MatchString(s) {
+ return Result{false, wrapError(errStr, objName...)}
+ }
+ return ResultOk
+ }
+}
+
+func makeStrLenFunc(errStr string, l int) strValidationFunc {
+ return func(s string, objName ...string) Result {
+ if len(s) != l {
+ return Result{false, wrapError(errStr, objName...)}
+ }
+ return ResultOk
+ }
+}
+
+func makeCombinedStrFunc(errStr string, funcs ...strValidationFunc) strValidationFunc {
+ return func(s string, objName ...string) Result {
+ for _, f := range funcs {
+ if res := f(s); !res.Ok {
+ return Result{false, wrapError(fmt.Sprintf(errStr+": %s", res.Err.Error()), objName...)}
+ }
+ }
+ return ResultOk
+ }
+}
+
+func wrapError(errStr string, prefix ...string) error {
+ if len(prefix) > 0 && prefix[0] != "" {
+ return fmt.Errorf("%s: %s", prefix[0], errStr)
+ }
+ return errors.New(errStr)
+}
diff --git a/pkg/validator/validator_test.go b/pkg/validator/validator_test.go
new file mode 100644
index 000000000..ef04d9cde
--- /dev/null
+++ b/pkg/validator/validator_test.go
@@ -0,0 +1,85 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package validator_test
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/google/syzkaller/pkg/validator"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestIsCommitHash(t *testing.T) {
+ assert.True(t, validator.CommitHash("b311c1b497e51a628aa89e7cb954481e5f9dced2").Ok)
+ assert.False(t, validator.CommitHash("").Ok)
+ assert.False(t, validator.CommitHash("b311").Ok)
+ assert.False(t, validator.CommitHash("+311c1b497e51a628aa89e7cb954481e5f9dced2").Ok)
+
+ assert.Equal(t, "not a hash: len is not 40", validator.CommitHash("b311").Err.Error())
+ assert.Equal(t, "valName: not a hash: len is not 40",
+ validator.CommitHash("b311", "valName").Err.Error())
+ assert.Equal(t, "valName: not a hash: not an alphanum",
+ validator.CommitHash("!311c1b497e51a628aa89e7cb954481e5f9dced2", "valName").Err.Error())
+}
+
+func TestIsNamespaceName(t *testing.T) {
+ assert.True(t, validator.NamespaceName("upstream").Ok)
+ assert.False(t, validator.NamespaceName("up").Ok)
+ assert.False(t, validator.NamespaceName("").Ok)
+
+ assert.Equal(t, "not a namespace name", validator.NamespaceName("up").Err.Error())
+ assert.Equal(t, "ns: not a namespace name",
+ validator.NamespaceName("up", "ns").Err.Error())
+}
+
+// nolint: dupl
+func TestIsDashboardClientName(t *testing.T) {
+ assert.True(t, validator.DashClientName("name").Ok)
+ assert.False(t, validator.DashClientName("").Ok)
+
+ assert.Equal(t, "not a dashboard client name", validator.DashClientName("cl").Err.Error())
+ assert.Equal(t, "client: not a dashboard client name",
+ validator.DashClientName("cl", "client").Err.Error())
+}
+
+// nolint: dupl
+func TestIsDashboardClientKey(t *testing.T) {
+ assert.True(t, validator.DashClientKey("b311c1b497e51a628aa89e7cb954481e5f9dced2").Ok)
+ assert.False(t, validator.DashClientKey("").Ok)
+
+ assert.Equal(t, "not a dashboard client key", validator.DashClientKey("key").Err.Error())
+ assert.Equal(t, "clientKey: not a dashboard client key",
+ validator.DashClientKey("clKey", "clientKey").Err.Error())
+}
+
+// nolint: dupl
+func TestIsKernelFilePath(t *testing.T) {
+ assert.True(t, validator.KernelFilePath("io_uring/advise.c").Ok)
+ assert.False(t, validator.KernelFilePath("io-uring/advise.c").Ok)
+
+ assert.Equal(t, "not a kernel file path", validator.KernelFilePath("io-uring").Err.Error())
+ assert.Equal(t, "kernelPath: not a kernel file path",
+ validator.KernelFilePath("io-uring", "kernelPath").Err.Error())
+}
+
+var badResult = validator.Result{false, errors.New("sample error")}
+
+func TestAnyError(t *testing.T) {
+ assert.Nil(t, validator.AnyError("prefix", validator.ResultOk, validator.ResultOk))
+ assert.Equal(t, "prefix: sample error",
+ validator.AnyError("prefix", validator.ResultOk, badResult).Error())
+}
+
+func TestPanicIfNot(t *testing.T) {
+ assert.NotPanics(t, func() { validator.PanicIfNot(validator.ResultOk, validator.ResultOk) })
+ assert.Panics(t, func() { validator.PanicIfNot(validator.ResultOk, badResult) })
+}
+
+func TestAnyOk(t *testing.T) {
+ assert.Equal(t, validator.ResultOk, validator.AnyOk())
+ assert.Equal(t, validator.ResultOk, validator.AnyOk(validator.ResultOk))
+ assert.Equal(t, badResult, validator.AnyOk(badResult))
+ assert.Equal(t, validator.ResultOk, validator.AnyOk(badResult, validator.ResultOk))
+}