diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2021-02-22 20:37:25 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2021-02-22 21:02:12 +0100 |
| commit | fcc6d71be2c3ce7d9305c04fc2e87af554571bac (patch) | |
| tree | b01dbb3d1e2988e28ea158d2d543d603ec0b9569 /vendor/github.com/charithe/durationcheck/durationcheck.go | |
| parent | 8f23c528ad5a943b9ffec5dcaf332fd0f614006e (diff) | |
go.mod: update golangci-lint to v1.37
Diffstat (limited to 'vendor/github.com/charithe/durationcheck/durationcheck.go')
| -rw-r--r-- | vendor/github.com/charithe/durationcheck/durationcheck.go | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/vendor/github.com/charithe/durationcheck/durationcheck.go b/vendor/github.com/charithe/durationcheck/durationcheck.go new file mode 100644 index 000000000..6eccd9c2a --- /dev/null +++ b/vendor/github.com/charithe/durationcheck/durationcheck.go @@ -0,0 +1,176 @@ +package durationcheck + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "log" + "os" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +var Analyzer = &analysis.Analyzer{ + Name: "durationcheck", + Doc: "check for two durations multiplied together", + Run: run, + Requires: []*analysis.Analyzer{inspect.Analyzer}, +} + +func run(pass *analysis.Pass) (interface{}, error) { + // if the package does not import time, it can be skipped from analysis + if !hasImport(pass.Pkg, "time") { + return nil, nil + } + + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeTypes := []ast.Node{ + (*ast.BinaryExpr)(nil), + } + + inspect.Preorder(nodeTypes, check(pass)) + + return nil, nil +} + +func hasImport(pkg *types.Package, importPath string) bool { + for _, imp := range pkg.Imports() { + if imp.Path() == importPath { + return true + } + } + + return false +} + +// check contains the logic for checking that time.Duration is used correctly in the code being analysed +func check(pass *analysis.Pass) func(ast.Node) { + return func(node ast.Node) { + expr := node.(*ast.BinaryExpr) + // we are only interested in multiplication + if expr.Op != token.MUL { + return + } + + // get the types of the two operands + x, xOK := pass.TypesInfo.Types[expr.X] + y, yOK := pass.TypesInfo.Types[expr.Y] + + if !xOK || !yOK { + return + } + + if isDuration(x.Type) && isDuration(y.Type) { + // check that both sides are acceptable expressions + if isUnacceptableExpr(pass, expr.X) && isUnacceptableExpr(pass, expr.Y) { + pass.Reportf(expr.Pos(), "Multiplication of durations: `%s`", formatNode(expr)) + } + } + } +} + +func isDuration(x types.Type) bool { + return x.String() == "time.Duration" +} + +// isUnacceptableExpr returns true if the argument is not an acceptable time.Duration expression +func isUnacceptableExpr(pass *analysis.Pass, expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.BasicLit: + return false + case *ast.Ident: + return !isAcceptableNestedExpr(pass, e) + case *ast.CallExpr: + return !isAcceptableCast(pass, e) + case *ast.BinaryExpr: + return !isAcceptableNestedExpr(pass, e) + case *ast.UnaryExpr: + return !isAcceptableNestedExpr(pass, e) + case *ast.SelectorExpr: + return !isAcceptableNestedExpr(pass, e) + } + + return true +} + +// isAcceptableCast returns true if the argument is an acceptable expression cast to time.Duration +func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool { + // check that there's a single argument + if len(e.Args) != 1 { + return false + } + + // check that the argument is acceptable + if !isAcceptableNestedExpr(pass, e.Args[0]) { + return false + } + + // check for time.Duration cast + selector, ok := e.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + + return isDurationCast(selector) +} + +func isDurationCast(selector *ast.SelectorExpr) bool { + pkg, ok := selector.X.(*ast.Ident) + if !ok { + return false + } + + if pkg.Name != "time" { + return false + } + + return selector.Sel.Name == "Duration" +} + +func isAcceptableNestedExpr(pass *analysis.Pass, n ast.Expr) bool { + switch e := n.(type) { + case *ast.BasicLit: + return true + case *ast.BinaryExpr: + return isAcceptableNestedExpr(pass, e.X) && isAcceptableNestedExpr(pass, e.Y) + case *ast.UnaryExpr: + return isAcceptableNestedExpr(pass, e.X) + case *ast.Ident: + return isAcceptableIdent(pass, e) + case *ast.CallExpr: + t := pass.TypesInfo.TypeOf(e) + return !isDuration(t) + case *ast.SelectorExpr: + t := pass.TypesInfo.TypeOf(e) + return !isDuration(t) + } + + return false +} + +func isAcceptableIdent(pass *analysis.Pass, ident *ast.Ident) bool { + obj := pass.TypesInfo.ObjectOf(ident) + return !isDuration(obj.Type()) +} + +func formatNode(node ast.Node) string { + buf := new(bytes.Buffer) + if err := format.Node(buf, token.NewFileSet(), node); err != nil { + log.Printf("Error formatting expression: %v", err) + return "" + } + + return buf.String() +} + +func printAST(msg string, node ast.Node) { + fmt.Printf(">>> %s:\n%s\n\n\n", msg, formatNode(node)) + ast.Fprint(os.Stdout, nil, node, nil) + fmt.Println("--------------") +} |
