aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/ryanrolds/sqlclosecheck/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/ryanrolds/sqlclosecheck/pkg')
-rw-r--r--vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go311
1 files changed, 311 insertions, 0 deletions
diff --git a/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
new file mode 100644
index 000000000..bc42dfb3a
--- /dev/null
+++ b/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
@@ -0,0 +1,311 @@
+package analyzer
+
+import (
+ "go/types"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/buildssa"
+ "golang.org/x/tools/go/ssa"
+)
+
+const (
+ rowsName = "Rows"
+ stmtName = "Stmt"
+ closeMethod = "Close"
+)
+
+var (
+ sqlPackages = []string{
+ "database/sql",
+ "github.com/jmoiron/sqlx",
+ }
+)
+
+func NewAnalyzer() *analysis.Analyzer {
+ return &analysis.Analyzer{
+ Name: "sqlclosecheck",
+ Doc: "Checks that sql.Rows and sql.Stmt are closed.",
+ Run: run,
+ Requires: []*analysis.Analyzer{
+ buildssa.Analyzer,
+ },
+ }
+}
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ pssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
+
+ // Build list of types we are looking for
+ targetTypes := getTargetTypes(pssa, sqlPackages)
+
+ // If non of the types are found, skip
+ if len(targetTypes) == 0 {
+ return nil, nil
+ }
+
+ funcs := pssa.SrcFuncs
+ for _, f := range funcs {
+ for _, b := range f.Blocks {
+ for i := range b.Instrs {
+ // Check if instruction is call that returns a target type
+ targetValues := getTargetTypesValues(b, i, targetTypes)
+ if len(targetValues) == 0 {
+ continue
+ }
+
+ // log.Printf("%s", f.Name())
+
+ // For each found target check if they are closed and deferred
+ for _, targetValue := range targetValues {
+ refs := (*targetValue.value).Referrers()
+ isClosed := checkClosed(refs, targetTypes)
+ if !isClosed {
+ pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt was not closed")
+ }
+
+ checkDeferred(pass, refs, targetTypes, false)
+ }
+ }
+ }
+ }
+
+ return nil, nil
+}
+
+func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointer {
+ targets := []*types.Pointer{}
+
+ for _, sqlPkg := range targetPackages {
+ pkg := pssa.Pkg.Prog.ImportedPackage(sqlPkg)
+ if pkg == nil {
+ // the SQL package being checked isn't imported
+ return targets
+ }
+
+ rowsType := getTypePointerFromName(pkg, rowsName)
+ if rowsType != nil {
+ targets = append(targets, rowsType)
+ }
+
+ stmtType := getTypePointerFromName(pkg, stmtName)
+ if stmtType != nil {
+ targets = append(targets, stmtType)
+ }
+ }
+
+ return targets
+}
+
+func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer {
+ pkgType := pkg.Type(name)
+ if pkgType == nil {
+ // this package does not use Rows/Stmt
+ return nil
+ }
+
+ obj := pkgType.Object()
+ named, ok := obj.Type().(*types.Named)
+ if !ok {
+ return nil
+ }
+
+ return types.NewPointer(named)
+}
+
+type targetValue struct {
+ value *ssa.Value
+ instr ssa.Instruction
+}
+
+func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer) []targetValue {
+ targetValues := []targetValue{}
+
+ instr := b.Instrs[i]
+ call, ok := instr.(*ssa.Call)
+ if !ok {
+ return targetValues
+ }
+
+ signature := call.Call.Signature()
+ results := signature.Results()
+ for i := 0; i < results.Len(); i++ {
+ v := results.At(i)
+ varType := v.Type()
+
+ for _, targetType := range targetTypes {
+ if !types.Identical(varType, targetType) {
+ continue
+ }
+
+ for _, cRef := range *call.Referrers() {
+ switch instr := cRef.(type) {
+ case *ssa.Call:
+ if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), targetType) {
+ targetValues = append(targetValues, targetValue{
+ value: &instr.Call.Args[0],
+ instr: call,
+ })
+ }
+ case ssa.Value:
+ if types.Identical(instr.Type(), targetType) {
+ targetValues = append(targetValues, targetValue{
+ value: &instr,
+ instr: call,
+ })
+ }
+ }
+ }
+ }
+ }
+
+ return targetValues
+}
+
+func checkClosed(refs *[]ssa.Instruction, targetTypes []*types.Pointer) bool {
+ numInstrs := len(*refs)
+ for idx, ref := range *refs {
+ // log.Printf("%T - %s", ref, ref)
+
+ action := getAction(ref, targetTypes)
+ switch action {
+ case "closed":
+ return true
+ case "passed":
+ // Passed and not used after
+ if numInstrs == idx+1 {
+ return true
+ }
+ case "returned":
+ return true
+ case "handled":
+ return true
+ default:
+ // log.Printf(action)
+ }
+ }
+
+ return false
+}
+
+func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) string {
+ switch instr := instr.(type) {
+ case *ssa.Defer:
+ if instr.Call.Value == nil {
+ return "unvalued defer"
+ }
+
+ name := instr.Call.Value.Name()
+ if name == closeMethod {
+ return "closed"
+ }
+ case *ssa.Call:
+ if instr.Call.Value == nil {
+ return "unvalued call"
+ }
+
+ isTarget := false
+ receiver := instr.Call.StaticCallee().Signature.Recv()
+ if receiver != nil {
+ isTarget = isTargetType(receiver.Type(), targetTypes)
+ }
+
+ name := instr.Call.Value.Name()
+ if isTarget && name == closeMethod {
+ return "closed"
+ }
+
+ if !isTarget {
+ return "passed"
+ }
+ case *ssa.Phi:
+ return "passed"
+ case *ssa.MakeInterface:
+ return "passed"
+ case *ssa.Store:
+ if len(*instr.Addr.Referrers()) == 0 {
+ return "noop"
+ }
+
+ for _, aRef := range *instr.Addr.Referrers() {
+ if c, ok := aRef.(*ssa.MakeClosure); ok {
+ f := c.Fn.(*ssa.Function)
+ for _, b := range f.Blocks {
+ if checkClosed(&b.Instrs, targetTypes) {
+ return "handled"
+ }
+ }
+ }
+ }
+ case *ssa.UnOp:
+ instrType := instr.Type()
+ for _, targetType := range targetTypes {
+ if types.Identical(instrType, targetType) {
+ if checkClosed(instr.Referrers(), targetTypes) {
+ return "handled"
+ }
+ }
+ }
+ case *ssa.FieldAddr:
+ if checkClosed(instr.Referrers(), targetTypes) {
+ return "handled"
+ }
+ case *ssa.Return:
+ return "returned"
+ default:
+ // log.Printf("%s", instr)
+ }
+
+ return "unhandled"
+}
+
+func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []*types.Pointer, inDefer bool) {
+ for _, instr := range *instrs {
+ switch instr := instr.(type) {
+ case *ssa.Defer:
+ if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
+ return
+ }
+ case *ssa.Call:
+ if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
+ if !inDefer {
+ pass.Reportf(instr.Pos(), "Close should use defer")
+ }
+
+ return
+ }
+ case *ssa.Store:
+ if len(*instr.Addr.Referrers()) == 0 {
+ return
+ }
+
+ for _, aRef := range *instr.Addr.Referrers() {
+ if c, ok := aRef.(*ssa.MakeClosure); ok {
+ f := c.Fn.(*ssa.Function)
+
+ for _, b := range f.Blocks {
+ checkDeferred(pass, &b.Instrs, targetTypes, true)
+ }
+ }
+ }
+ case *ssa.UnOp:
+ instrType := instr.Type()
+ for _, targetType := range targetTypes {
+ if types.Identical(instrType, targetType) {
+ checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
+ }
+ }
+ case *ssa.FieldAddr:
+ checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
+ }
+ }
+}
+
+func isTargetType(t types.Type, targetTypes []*types.Pointer) bool {
+ for _, targetType := range targetTypes {
+ if types.Identical(t, targetType) {
+ return true
+ }
+ }
+
+ return false
+}