aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/Djarvur/go-err113/comparison.go
blob: 0ffe2863c4c5f5b6d606ecf0e4adaf63ac090954 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package err113

import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"

	"golang.org/x/tools/go/analysis"
)

// inspectComparision reports every == / != comparison whose operands
// areBothErrors accepts. For each one it attaches a suggested fix that
// rewrites the comparison as errors.Is(x, y), negated for !=. It always
// returns true so the AST walk continues into child nodes.
//
// NOTE(review): the suggested fix only replaces the comparison text; it does
// not add an import of "errors" to the file being fixed.
func inspectComparision(pass *analysis.Pass, n ast.Node) bool { // nolint: unparam
	// only binary expressions can be error comparisons
	be, ok := n.(*ast.BinaryExpr)
	if !ok {
		return true
	}

	// check if it is a comparison operation
	if be.Op != token.EQL && be.Op != token.NEQ {
		return true
	}

	if !areBothErrors(be.X, be.Y, pass.TypesInfo) {
		return true
	}

	oldExpr := render(pass.Fset, be)

	negate := ""
	if be.Op == token.NEQ {
		negate = "!"
	}

	// Render the operands back to source text. Formatting an ast.Expr with %s
	// only works for *ast.Ident (which has a String method); selectors such as
	// io.ErrUnexpectedEOF or calls would otherwise print as struct dumps
	// ("&{0xc000... ErrX}") and the suggested fix would not compile.
	newExpr := fmt.Sprintf(
		"%s%s.Is(%s, %s)",
		negate, "errors", render(pass.Fset, be.X), render(pass.Fset, be.Y),
	)

	pass.Report(
		analysis.Diagnostic{
			Pos:     be.Pos(),
			Message: fmt.Sprintf("do not compare errors directly, use errors.Is() instead: %q", oldExpr),
			SuggestedFixes: []analysis.SuggestedFix{
				{
					Message: fmt.Sprintf("should replace %q with %q", oldExpr, newExpr),
					TextEdits: []analysis.TextEdit{
						{
							// replace exactly the comparison expression
							Pos:     be.Pos(),
							End:     be.End(),
							NewText: []byte(newExpr),
						},
					},
				},
			},
		},
	)

	return true
}

func isError(v ast.Expr, info *types.Info) bool {
	if intf, ok := info.TypeOf(v).Underlying().(*types.Interface); ok {
		return intf.NumMethods() == 1 && intf.Method(0).FullName() == "(error).Error"
	}

	return false
}

// isEOF reports whether ex is a selector X.EOF whose qualifier X the type
// checker resolved to an imported package with import path "io". A local
// variable or a differently-pathed package named io does not match, while an
// aliased import of "io" does.
func isEOF(ex ast.Expr, info *types.Info) bool {
	sel, isSelector := ex.(*ast.SelectorExpr)
	if !isSelector {
		return false
	}

	if sel.Sel.Name != "EOF" {
		return false
	}

	importPath, isPkg := asImportedName(sel.X, info)

	return isPkg && importPath == "io"
}

func asImportedName(ex ast.Expr, info *types.Info) (string, bool) {
	ei, ok := ex.(*ast.Ident)
	if !ok {
		return "", false
	}

	ep, ok := info.ObjectOf(ei).(*types.PkgName)
	if !ok {
		return "", false
	}

	return ep.Imported().Path(), true
}

// areBothErrors decides whether a comparison between x and y should be
// flagged. It returns false if either operand is the untyped nil constant (the
// idiomatic err != nil check) or the io.EOF sentinel. Otherwise it returns
// true when at least one operand has the error interface type. Despite the
// name, only one side needs to be an error.
func areBothErrors(x, y ast.Expr, typesInfo *types.Info) bool {
	switch {
	case typesInfo.Types[x].IsNil(), typesInfo.Types[y].IsNil():
		// comparing against nil is always allowed
		return false
	case isEOF(x, typesInfo), isEOF(y, typesInfo):
		// io.EOF is exempt: the io package documents that it is returned
		// unwrapped and tested with ==
		return false
	default:
		return isError(x, typesInfo) || isError(y, typesInfo)
	}
}