From 2ab72b4feef2c97f22f90cfbf9e45a6cfcd08bda Mon Sep 17 00:00:00 2001 From: Taras Madan Date: Tue, 5 Dec 2023 15:10:03 +0100 Subject: vendor: updates --- .../sqlclosecheck/pkg/analyzer/analyzer.go | 150 +++++++++++++++------ 1 file changed, 109 insertions(+), 41 deletions(-) (limited to 'vendor/github.com/ryanrolds/sqlclosecheck') diff --git a/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go index c22817caf..55e931a89 100644 --- a/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go +++ b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go @@ -9,9 +9,10 @@ import ( ) const ( - rowsName = "Rows" - stmtName = "Stmt" - closeMethod = "Close" + rowsName = "Rows" + stmtName = "Stmt" + namedStmtName = "NamedStmt" + closeMethod = "Close" ) type action uint8 @@ -31,13 +32,15 @@ var ( sqlPackages = []string{ "database/sql", "github.com/jmoiron/sqlx", + "github.com/jackc/pgx/v5", + "github.com/jackc/pgx/v5/pgxpool", } ) func NewAnalyzer() *analysis.Analyzer { return &analysis.Analyzer{ Name: "sqlclosecheck", - Doc: "Checks that sql.Rows and sql.Stmt are closed.", + Doc: "Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed.", Run: run, Requires: []*analysis.Analyzer{ buildssa.Analyzer, @@ -63,20 +66,18 @@ func run(pass *analysis.Pass) (interface{}, error) { for _, f := range funcs { for _, b := range f.Blocks { for i := range b.Instrs { - // Check if instruction is call that returns a target type + // Check if instruction is call that returns a target pointer type targetValues := getTargetTypesValues(b, i, targetTypes) if len(targetValues) == 0 { continue } - // log.Printf("%s", f.Name()) - // For each found target check if they are closed and deferred for _, targetValue := range targetValues { refs := (*targetValue.value).Referrers() isClosed := checkClosed(refs, targetTypes) if !isClosed { - pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt was not closed") + pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt/NamedStmt was not closed") } checkDeferred(pass, refs, targetTypes, false) @@ -88,17 +89,22 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } -func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointer { - targets := []*types.Pointer{} +func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []any { + targets := []any{} for _, sqlPkg := range targetPackages { pkg := pssa.Pkg.Prog.ImportedPackage(sqlPkg) if pkg == nil { // the SQL package being checked isn't imported - return targets + continue + } + + rowsPtrType := getTypePointerFromName(pkg, rowsName) + if rowsPtrType != nil { + targets = append(targets, rowsPtrType) } - rowsType := getTypePointerFromName(pkg, rowsName) + rowsType := getTypeFromName(pkg, rowsName) if rowsType != nil { targets = append(targets, rowsType) } @@ -107,6 +113,11 @@ func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointe if stmtType != nil { targets = append(targets, stmtType) } + + namedStmtType := getTypePointerFromName(pkg, namedStmtName) + if namedStmtType != nil { + targets = append(targets, namedStmtType) + } } return targets @@ -115,7 +126,7 @@ func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointe func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer { pkgType := pkg.Type(name) if pkgType == nil { - // this package does not use Rows/Stmt + // this package does not use Rows/Stmt/NamedStmt return nil } @@ -128,12 +139,28 @@ func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer { return types.NewPointer(named) } +func getTypeFromName(pkg *ssa.Package, name string) *types.Named { + pkgType := pkg.Type(name) + if pkgType == nil { + // this package does not use Rows/Stmt + return nil + } + + obj := pkgType.Object() + named, ok := obj.Type().(*types.Named) + if !ok { + return nil + } + + return named +} + type targetValue struct { value *ssa.Value instr ssa.Instruction } -func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer) []targetValue { +func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []any) []targetValue { targetValues := []targetValue{} instr := b.Instrs[i] @@ -149,21 +176,32 @@ func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer varType := v.Type() for _, targetType := range targetTypes { - if !types.Identical(varType, targetType) { + var tt types.Type + + switch t := targetType.(type) { + case *types.Pointer: + tt = t + case *types.Named: + tt = t + default: + continue + } + + if !types.Identical(varType, tt) { continue } for _, cRef := range *call.Referrers() { switch instr := cRef.(type) { case *ssa.Call: - if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), targetType) { + if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), tt) { targetValues = append(targetValues, targetValue{ value: &instr.Call.Args[0], instr: call, }) } case ssa.Value: - if types.Identical(instr.Type(), targetType) { + if types.Identical(instr.Type(), tt) { targetValues = append(targetValues, targetValue{ value: &instr, instr: call, @@ -177,43 +215,42 @@ func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer return targetValues } -func checkClosed(refs *[]ssa.Instruction, targetTypes []*types.Pointer) bool { +func checkClosed(refs *[]ssa.Instruction, targetTypes []any) bool { numInstrs := len(*refs) for idx, ref := range *refs { - // log.Printf("%T - %s", ref, ref) - action := getAction(ref, targetTypes) switch action { - case actionClosed: + case actionClosed, actionReturned, actionHandled: return true case actionPassed: // Passed and not used after if numInstrs == idx+1 { return true } - case actionReturned: - return true - case actionHandled: - return true - default: - // log.Printf(action) } } return false } -func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action { +func getAction(instr ssa.Instruction, targetTypes []any) action { switch instr := instr.(type) { case *ssa.Defer: - if instr.Call.Value == nil { - return actionUnvaluedDefer + if instr.Call.Value != nil { + name := instr.Call.Value.Name() + if name == closeMethod { + return actionClosed + } } - name := instr.Call.Value.Name() - if name == closeMethod { - return actionClosed + if instr.Call.Method != nil { + name := instr.Call.Method.Name() + if name == closeMethod { + return actionClosed + } } + + return actionUnvaluedDefer case *ssa.Call: if instr.Call.Value == nil { return actionUnvaluedCall @@ -265,7 +302,18 @@ func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action { case *ssa.UnOp: instrType := instr.Type() for _, targetType := range targetTypes { - if types.Identical(instrType, targetType) { + var tt types.Type + + switch t := targetType.(type) { + case *types.Pointer: + tt = t + case *types.Named: + tt = t + default: + continue + } + + if types.Identical(instrType, tt) { if checkClosed(instr.Referrers(), targetTypes) { return actionHandled } @@ -277,20 +325,22 @@ func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action { } case *ssa.Return: return actionReturned - default: - // log.Printf("%s", instr) } return actionUnhandled } -func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []*types.Pointer, inDefer bool) { +func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []any, inDefer bool) { for _, instr := range *instrs { switch instr := instr.(type) { case *ssa.Defer: if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod { return } + + if instr.Call.Method != nil && instr.Call.Method.Name() == closeMethod { + return + } case *ssa.Call: if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod { if !inDefer { @@ -316,7 +366,18 @@ func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes [ case *ssa.UnOp: instrType := instr.Type() for _, targetType := range targetTypes { - if types.Identical(instrType, targetType) { + var tt types.Type + + switch t := targetType.(type) { + case *types.Pointer: + tt = t + case *types.Named: + tt = t + default: + continue + } + + if types.Identical(instrType, tt) { checkDeferred(pass, instr.Referrers(), targetTypes, inDefer) } } @@ -326,10 +387,17 @@ func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes [ } } -func isTargetType(t types.Type, targetTypes []*types.Pointer) bool { +func isTargetType(t types.Type, targetTypes []any) bool { for _, targetType := range targetTypes { - if types.Identical(t, targetType) { - return true + switch tt := targetType.(type) { + case *types.Pointer: + if types.Identical(t, tt) { + return true + } + case *types.Named: + if types.Identical(t, tt) { + return true + } } } -- cgit mrf-deployment