aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/ryanrolds/sqlclosecheck/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ryanrolds/sqlclosecheck/pkg')
-rw-r--r--vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go150
1 files changed, 109 insertions, 41 deletions
diff --git a/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
index c22817caf..55e931a89 100644
--- a/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
+++ b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
@@ -9,9 +9,10 @@ import (
)
const (
- rowsName = "Rows"
- stmtName = "Stmt"
- closeMethod = "Close"
+ rowsName = "Rows"
+ stmtName = "Stmt"
+ namedStmtName = "NamedStmt"
+ closeMethod = "Close"
)
type action uint8
@@ -31,13 +32,15 @@ var (
sqlPackages = []string{
"database/sql",
"github.com/jmoiron/sqlx",
+ "github.com/jackc/pgx/v5",
+ "github.com/jackc/pgx/v5/pgxpool",
}
)
func NewAnalyzer() *analysis.Analyzer {
return &analysis.Analyzer{
Name: "sqlclosecheck",
- Doc: "Checks that sql.Rows and sql.Stmt are closed.",
+ Doc: "Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed.",
Run: run,
Requires: []*analysis.Analyzer{
buildssa.Analyzer,
@@ -63,20 +66,18 @@ func run(pass *analysis.Pass) (interface{}, error) {
for _, f := range funcs {
for _, b := range f.Blocks {
for i := range b.Instrs {
- // Check if instruction is call that returns a target type
+ // Check if instruction is call that returns a target pointer type
targetValues := getTargetTypesValues(b, i, targetTypes)
if len(targetValues) == 0 {
continue
}
- // log.Printf("%s", f.Name())
-
// For each found target check if they are closed and deferred
for _, targetValue := range targetValues {
refs := (*targetValue.value).Referrers()
isClosed := checkClosed(refs, targetTypes)
if !isClosed {
- pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt was not closed")
+ pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt/NamedStmt was not closed")
}
checkDeferred(pass, refs, targetTypes, false)
@@ -88,17 +89,22 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}
-func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointer {
- targets := []*types.Pointer{}
+func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []any {
+ targets := []any{}
for _, sqlPkg := range targetPackages {
pkg := pssa.Pkg.Prog.ImportedPackage(sqlPkg)
if pkg == nil {
// the SQL package being checked isn't imported
- return targets
+ continue
+ }
+
+ rowsPtrType := getTypePointerFromName(pkg, rowsName)
+ if rowsPtrType != nil {
+ targets = append(targets, rowsPtrType)
}
- rowsType := getTypePointerFromName(pkg, rowsName)
+ rowsType := getTypeFromName(pkg, rowsName)
if rowsType != nil {
targets = append(targets, rowsType)
}
@@ -107,6 +113,11 @@ func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointe
if stmtType != nil {
targets = append(targets, stmtType)
}
+
+ namedStmtType := getTypePointerFromName(pkg, namedStmtName)
+ if namedStmtType != nil {
+ targets = append(targets, namedStmtType)
+ }
}
return targets
@@ -115,7 +126,7 @@ func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointe
func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer {
pkgType := pkg.Type(name)
if pkgType == nil {
- // this package does not use Rows/Stmt
+ // this package does not use Rows/Stmt/NamedStmt
return nil
}
@@ -128,12 +139,28 @@ func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer {
return types.NewPointer(named)
}
+func getTypeFromName(pkg *ssa.Package, name string) *types.Named {
+ pkgType := pkg.Type(name)
+ if pkgType == nil {
+ // this package does not use Rows/Stmt
+ return nil
+ }
+
+ obj := pkgType.Object()
+ named, ok := obj.Type().(*types.Named)
+ if !ok {
+ return nil
+ }
+
+ return named
+}
+
type targetValue struct {
value *ssa.Value
instr ssa.Instruction
}
-func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer) []targetValue {
+func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []any) []targetValue {
targetValues := []targetValue{}
instr := b.Instrs[i]
@@ -149,21 +176,32 @@ func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer
varType := v.Type()
for _, targetType := range targetTypes {
- if !types.Identical(varType, targetType) {
+ var tt types.Type
+
+ switch t := targetType.(type) {
+ case *types.Pointer:
+ tt = t
+ case *types.Named:
+ tt = t
+ default:
+ continue
+ }
+
+ if !types.Identical(varType, tt) {
continue
}
for _, cRef := range *call.Referrers() {
switch instr := cRef.(type) {
case *ssa.Call:
- if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), targetType) {
+ if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), tt) {
targetValues = append(targetValues, targetValue{
value: &instr.Call.Args[0],
instr: call,
})
}
case ssa.Value:
- if types.Identical(instr.Type(), targetType) {
+ if types.Identical(instr.Type(), tt) {
targetValues = append(targetValues, targetValue{
value: &instr,
instr: call,
@@ -177,43 +215,42 @@ func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer
return targetValues
}
-func checkClosed(refs *[]ssa.Instruction, targetTypes []*types.Pointer) bool {
+func checkClosed(refs *[]ssa.Instruction, targetTypes []any) bool {
numInstrs := len(*refs)
for idx, ref := range *refs {
- // log.Printf("%T - %s", ref, ref)
-
action := getAction(ref, targetTypes)
switch action {
- case actionClosed:
+ case actionClosed, actionReturned, actionHandled:
return true
case actionPassed:
// Passed and not used after
if numInstrs == idx+1 {
return true
}
- case actionReturned:
- return true
- case actionHandled:
- return true
- default:
- // log.Printf(action)
}
}
return false
}
-func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action {
+func getAction(instr ssa.Instruction, targetTypes []any) action {
switch instr := instr.(type) {
case *ssa.Defer:
- if instr.Call.Value == nil {
- return actionUnvaluedDefer
+ if instr.Call.Value != nil {
+ name := instr.Call.Value.Name()
+ if name == closeMethod {
+ return actionClosed
+ }
}
- name := instr.Call.Value.Name()
- if name == closeMethod {
- return actionClosed
+ if instr.Call.Method != nil {
+ name := instr.Call.Method.Name()
+ if name == closeMethod {
+ return actionClosed
+ }
}
+
+ return actionUnvaluedDefer
case *ssa.Call:
if instr.Call.Value == nil {
return actionUnvaluedCall
@@ -265,7 +302,18 @@ func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action {
case *ssa.UnOp:
instrType := instr.Type()
for _, targetType := range targetTypes {
- if types.Identical(instrType, targetType) {
+ var tt types.Type
+
+ switch t := targetType.(type) {
+ case *types.Pointer:
+ tt = t
+ case *types.Named:
+ tt = t
+ default:
+ continue
+ }
+
+ if types.Identical(instrType, tt) {
if checkClosed(instr.Referrers(), targetTypes) {
return actionHandled
}
@@ -277,20 +325,22 @@ func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) action {
}
case *ssa.Return:
return actionReturned
- default:
- // log.Printf("%s", instr)
}
return actionUnhandled
}
-func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []*types.Pointer, inDefer bool) {
+func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []any, inDefer bool) {
for _, instr := range *instrs {
switch instr := instr.(type) {
case *ssa.Defer:
if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
return
}
+
+ if instr.Call.Method != nil && instr.Call.Method.Name() == closeMethod {
+ return
+ }
case *ssa.Call:
if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
if !inDefer {
@@ -316,7 +366,18 @@ func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes [
case *ssa.UnOp:
instrType := instr.Type()
for _, targetType := range targetTypes {
- if types.Identical(instrType, targetType) {
+ var tt types.Type
+
+ switch t := targetType.(type) {
+ case *types.Pointer:
+ tt = t
+ case *types.Named:
+ tt = t
+ default:
+ continue
+ }
+
+ if types.Identical(instrType, tt) {
checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
}
}
@@ -326,10 +387,17 @@ func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes [
}
}
-func isTargetType(t types.Type, targetTypes []*types.Pointer) bool {
+func isTargetType(t types.Type, targetTypes []any) bool {
for _, targetType := range targetTypes {
- if types.Identical(t, targetType) {
- return true
+ switch tt := targetType.(type) {
+ case *types.Pointer:
+ if types.Identical(t, tt) {
+ return true
+ }
+ case *types.Named:
+ if types.Identical(t, tt) {
+ return true
+ }
}
}