aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/stretchr/testify/mock
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-09-15 18:05:35 +0200
committerDmitry Vyukov <dvyukov@google.com>2020-09-15 19:34:30 +0200
commit712de1c63d9db97c81af68cd0dc4372c53d2e57a (patch)
treeae1761fec52c3ae4ddd003a4130ddbda8d0a2d69 /vendor/github.com/stretchr/testify/mock
parent298a69c38dd5c8a9bbd7a022e88f4ddbcf885e16 (diff)
vendor/github.com/golangci/golangci-lint: update to v1.31
Diffstat (limited to 'vendor/github.com/stretchr/testify/mock')
-rw-r--r--vendor/github.com/stretchr/testify/mock/mock.go68
1 files changed, 66 insertions, 2 deletions
diff --git a/vendor/github.com/stretchr/testify/mock/mock.go b/vendor/github.com/stretchr/testify/mock/mock.go
index 58e0798da..c6df4485a 100644
--- a/vendor/github.com/stretchr/testify/mock/mock.go
+++ b/vendor/github.com/stretchr/testify/mock/mock.go
@@ -65,6 +65,11 @@ type Call struct {
// reference. It's useful when mocking methods such as unmarshalers or
// decoders.
RunFn func(Arguments)
+
+ // PanicMsg holds msg to be used to mock panic on the function call
+ // if the PanicMsg is set to a non nil string the function call will panic
+ // irrespective of other settings
+ PanicMsg *string
}
func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
@@ -77,6 +82,7 @@ func newCall(parent *Mock, methodName string, callerInfo []string, methodArgumen
Repeatability: 0,
WaitFor: nil,
RunFn: nil,
+ PanicMsg: nil,
}
}
@@ -100,6 +106,18 @@ func (c *Call) Return(returnArguments ...interface{}) *Call {
return c
}
+// Panic specifies if the functon call should fail and the panic message
+//
+// Mock.On("DoSomething").Panic("test panic")
+func (c *Call) Panic(msg string) *Call {
+ c.lock()
+ defer c.unlock()
+
+ c.PanicMsg = &msg
+
+ return c
+}
+
// Once indicates that that the mock should only return the value once.
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
@@ -150,7 +168,7 @@ func (c *Call) After(d time.Duration) *Call {
// mocking a method (such as an unmarshaler) that takes a pointer to a struct and
// sets properties in such struct
//
-// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(func(args Arguments) {
+// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
// arg := args.Get(0).(*map[string]interface{})
// arg["foo"] = "bar"
// })
@@ -393,6 +411,13 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
}
m.mutex.Lock()
+ panicMsg := call.PanicMsg
+ m.mutex.Unlock()
+ if panicMsg != nil {
+ panic(*panicMsg)
+ }
+
+ m.mutex.Lock()
runFn := call.RunFn
m.mutex.Unlock()
@@ -527,6 +552,45 @@ func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...inter
return true
}
+// IsMethodCallable checking that the method can be called
+// If the method was called more than `Repeatability` return false
+func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
+ if h, ok := t.(tHelper); ok {
+ h.Helper()
+ }
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ for _, v := range m.ExpectedCalls {
+ if v.Method != methodName {
+ continue
+ }
+ if len(arguments) != len(v.Arguments) {
+ continue
+ }
+ if v.Repeatability < v.totalCalls {
+ continue
+ }
+ if isArgsEqual(v.Arguments, arguments) {
+ return true
+ }
+ }
+ return false
+}
+
+// isArgsEqual compares arguments
+func isArgsEqual(expected Arguments, args []interface{}) bool {
+ if len(expected) != len(args) {
+ return false
+ }
+ for i, v := range args {
+ if !reflect.DeepEqual(expected[i], v) {
+ return false
+ }
+ }
+ return true
+}
+
func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
for _, call := range m.calls() {
if call.Method == methodName {
@@ -791,7 +855,7 @@ func (args Arguments) String(indexOrNil ...int) string {
// normal String() method - return a string representation of the args
var argsStr []string
for _, arg := range args {
- argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg)))
+ argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely
}
return strings.Join(argsStr, ",")
} else if len(indexOrNil) == 1 {