aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/stretchr/testify/mock
diff options
context:
space:
mode:
authorTaras Madan <tarasmadan@google.com>2022-09-05 14:27:54 +0200
committerGitHub <noreply@github.com>2022-09-05 12:27:54 +0000
commitb2f2446b46bf02821d90ebedadae2bf7ae0e880e (patch)
tree923cf42842918d6bebca1d6bbdc08abed54d274d /vendor/github.com/stretchr/testify/mock
parente6654faff4bcca4be92e9a8596fd4b77f747c39e (diff)
go.mod, vendor: update (#3358)
* go.mod, vendor: remove unnecessary dependencies Commands: 1. go mod tidy 2. go mod vendor * go.mod, vendor: update cloud.google.com/go Commands: 1. go get -u cloud.google.com/go 2. go mod tidy 3. go mod vendor * go.mod, vendor: update cloud.google.com/* Commands: 1. go get -u cloud.google.com/storage cloud.google.com/logging 2. go mod tidy 3. go mod vendor * go.mod, .golangci.yml, vendor: update *lint* Commands: 1. go get -u golang.org/x/tools github.com/golangci/golangci-lint@v1.47.0 2. go mod tidy 3. go mod vendor 4. edit .golangci.yml to suppress new errors (resolved in the same PR later) * all: fix lint errors hash.go: copy() recommended by gosimple parse.go: ent is never nil verifier.go: signal.Notify() with unbuffered channel is bad. Have no idea why. * .golangci.yml: adjust godot rules check-all is deprecated, but still work if you're hesitating too - I'll remove this commit
Diffstat (limited to 'vendor/github.com/stretchr/testify/mock')
-rw-r--r--vendor/github.com/stretchr/testify/mock/mock.go162
1 files changed, 126 insertions, 36 deletions
diff --git a/vendor/github.com/stretchr/testify/mock/mock.go b/vendor/github.com/stretchr/testify/mock/mock.go
index e2e6a2d23..f0af8246c 100644
--- a/vendor/github.com/stretchr/testify/mock/mock.go
+++ b/vendor/github.com/stretchr/testify/mock/mock.go
@@ -70,6 +70,9 @@ type Call struct {
// if the PanicMsg is set to a non nil string the function call will panic
// irrespective of other settings
PanicMsg *string
+
+ // Calls which must be satisfied before this call can be
+ requires []*Call
}
func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
@@ -199,6 +202,64 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call {
return c.Parent.On(methodName, arguments...)
}
+// Unset removes a mock handler from being called.
+// test.On("func", mock.Anything).Unset()
+func (c *Call) Unset() *Call {
+ var unlockOnce sync.Once
+
+ for _, arg := range c.Arguments {
+ if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
+ panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
+ }
+ }
+
+ c.lock()
+ defer unlockOnce.Do(c.unlock)
+
+ foundMatchingCall := false
+
+ for i, call := range c.Parent.ExpectedCalls {
+ if call.Method == c.Method {
+ _, diffCount := call.Arguments.Diff(c.Arguments)
+ if diffCount == 0 {
+ foundMatchingCall = true
+ // Remove from ExpectedCalls
+ c.Parent.ExpectedCalls = append(c.Parent.ExpectedCalls[:i], c.Parent.ExpectedCalls[i+1:]...)
+ }
+ }
+ }
+
+ if !foundMatchingCall {
+ unlockOnce.Do(c.unlock)
+ c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
+ callString(c.Method, c.Arguments, true),
+ )
+ }
+
+ return c
+}
+
+// NotBefore indicates that the mock should only be called after the referenced
+// calls have been called as expected. The referenced calls may be from the
+// same mock instance and/or other mock instances.
+//
+// Mock.On("Do").Return(nil).Notbefore(
+// Mock.On("Init").Return(nil)
+// )
+func (c *Call) NotBefore(calls ...*Call) *Call {
+ c.lock()
+ defer c.unlock()
+
+ for _, call := range calls {
+ if call.Parent == nil {
+ panic("not before calls must be created with Mock.On()")
+ }
+ }
+
+ c.requires = append(c.requires, calls...)
+ return c
+}
+
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top
// of this document.
@@ -221,10 +282,17 @@ type Mock struct {
mutex sync.Mutex
}
+// String provides a %v format string for Mock.
+// Note: this is used implicitly by Arguments.Diff if a Mock is passed.
+// It exists because go's default %v formatting traverses the struct
+// without acquiring the mutex, which is detected by go test -race.
+func (m *Mock) String() string {
+ return fmt.Sprintf("%[1]T<%[1]p>", m)
+}
+
// TestData holds any data that might be useful for testing. Testify ignores
// this data completely allowing you to do whatever you like with it.
func (m *Mock) TestData() objx.Map {
-
if m.testData == nil {
m.testData = make(objx.Map)
}
@@ -346,7 +414,6 @@ func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call,
}
func callString(method string, arguments Arguments, includeArgumentValues bool) string {
-
var argValsString string
if includeArgumentValues {
var argVals []string
@@ -370,10 +437,10 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
panic("Couldn't get the caller information")
}
functionPath := runtime.FuncForPC(pc).Name()
- //Next four lines are required to use GCCGO function naming conventions.
- //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
- //uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
- //With GCCGO we need to remove interface information starting from pN<dd>.
+ // Next four lines are required to use GCCGO function naming conventions.
+ // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
+ // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
+ // With GCCGO we need to remove interface information starting from pN<dd>.
re := regexp.MustCompile("\\.pN\\d+_")
if re.MatchString(functionPath) {
functionPath = re.Split(functionPath, -1)[0]
@@ -389,7 +456,7 @@ func (m *Mock) Called(arguments ...interface{}) Arguments {
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
m.mutex.Lock()
- //TODO: could combine expected and closes in single loop
+ // TODO: could combine expected and closes in single loop
found, call := m.findExpectedCall(methodName, arguments...)
if found < 0 {
@@ -419,6 +486,25 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
}
}
+ for _, requirement := range call.requires {
+ if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
+ m.mutex.Unlock()
+ m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
+ callString(call.Method, call.Arguments, true),
+ func() (s string) {
+ if requirement.totalCalls > 0 {
+ s = " another call of"
+ }
+ if call.Parent != requirement.Parent {
+ s += " method from another mock instance"
+ }
+ return
+ }(),
+ callString(requirement.Method, requirement.Arguments, true),
+ )
+ }
+ }
+
if call.Repeatability == 1 {
call.Repeatability = -1
} else if call.Repeatability > 1 {
@@ -476,9 +562,9 @@ func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
h.Helper()
}
for _, obj := range testObjects {
- if m, ok := obj.(Mock); ok {
+ if m, ok := obj.(*Mock); ok {
t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
- obj = &m
+ obj = m
}
m := obj.(assertExpectationser)
if !m.AssertExpectations(t) {
@@ -495,34 +581,36 @@ func (m *Mock) AssertExpectations(t TestingT) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
+
m.mutex.Lock()
defer m.mutex.Unlock()
- var somethingMissing bool
var failedExpectations int
// iterate through each expectation
expectedCalls := m.expectedCalls()
for _, expectedCall := range expectedCalls {
- if !expectedCall.optional && !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 {
- somethingMissing = true
+ satisfied, reason := m.checkExpectation(expectedCall)
+ if !satisfied {
failedExpectations++
- t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
- } else {
- if expectedCall.Repeatability > 0 {
- somethingMissing = true
- failedExpectations++
- t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo)
- } else {
- t.Logf("PASS:\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
- }
}
+ t.Logf(reason)
}
- if somethingMissing {
+ if failedExpectations != 0 {
t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
}
- return !somethingMissing
+ return failedExpectations == 0
+}
+
+func (m *Mock) checkExpectation(call *Call) (bool, string) {
+ if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
+ return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
+ }
+ if call.Repeatability > 0 {
+ return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
+ }
+ return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
}
// AssertNumberOfCalls asserts that the method was called expectedCalls times.
@@ -720,7 +808,7 @@ func (f argumentMatcher) Matches(argument interface{}) bool {
}
func (f argumentMatcher) String() string {
- return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
+ return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
}
// MatchedBy can be used to match a mock call based on only certain properties
@@ -773,12 +861,12 @@ func (args Arguments) Is(objects ...interface{}) bool {
//
// Returns the diff string and number of differences found.
func (args Arguments) Diff(objects []interface{}) (string, int) {
- //TODO: could return string as error and nil for No difference
+ // TODO: could return string as error and nil for No difference
- var output = "\n"
+ output := "\n"
var differences int
- var maxArgCount = len(args)
+ maxArgCount := len(args)
if len(objects) > maxArgCount {
maxArgCount = len(objects)
}
@@ -804,21 +892,28 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}
if matcher, ok := expected.(argumentMatcher); ok {
- if matcher.Matches(actual) {
+ var matches bool
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
+ }
+ }()
+ matches = matcher.Matches(actual)
+ }()
+ if matches {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
-
// type checking
if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
-
} else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) {
t := expected.(*IsTypeArgument).t
if reflect.TypeOf(t) != reflect.TypeOf(actual) {
@@ -826,7 +921,6 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt)
}
} else {
-
// normal checking
if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
@@ -846,7 +940,6 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}
return output, differences
-
}
// Assert compares the arguments with the specified objects and fails if
@@ -868,7 +961,6 @@ func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
t.Errorf("%sArguments do not match.", assert.CallerInfo())
return false
-
}
// String gets the argument at the specified index. Panics if there is no argument, or
@@ -877,7 +969,6 @@ func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
// If no index is provided, String() returns a complete string representation
// of the arguments.
func (args Arguments) String(indexOrNil ...int) string {
-
if len(indexOrNil) == 0 {
// normal String() method - return a string representation of the args
var argsStr []string
@@ -887,7 +978,7 @@ func (args Arguments) String(indexOrNil ...int) string {
return strings.Join(argsStr, ",")
} else if len(indexOrNil) == 1 {
// Index has been specified - get the argument at that index
- var index = indexOrNil[0]
+ index := indexOrNil[0]
var s string
var ok bool
if s, ok = args.Get(index).(string); !ok {
@@ -897,7 +988,6 @@ func (args Arguments) String(indexOrNil ...int) string {
}
panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
-
}
// Int gets the argument at the specified index. Panics if there is no argument, or