package mock import ( "fmt" "reflect" "regexp" "runtime" "strings" "sync" "time" "github.com/stretchr/objx" "github.com/stretchr/testify/assert" ) // TestingT is an interface wrapper around *testing.T type TestingT interface { Logf(format string, args ...interface{}) Errorf(format string, args ...interface{}) FailNow() } /* Call */ // Call represents a method call and is used for setting expectations, // as well as recording activity. type Call struct { Parent *Mock // The name of the method that was or will be called. Method string // Holds the arguments of the method. Arguments Arguments // Holds the arguments that should be returned when // this method is called. ReturnArguments Arguments // The number of times to return the return arguments when setting // expectations. 0 means to always return the value. Repeatability int // Amount of times this call has been called totalCalls int // Holds a channel that will be used to block the Return until it either // recieves a message or is closed. nil means it returns immediately. WaitFor <-chan time.Time // Holds a handler used to manipulate arguments content that are passed by // reference. It's useful when mocking methods such as unmarshalers or // decoders. RunFn func(Arguments) } func newCall(parent *Mock, methodName string, methodArguments ...interface{}) *Call { return &Call{ Parent: parent, Method: methodName, Arguments: methodArguments, ReturnArguments: make([]interface{}, 0), Repeatability: 0, WaitFor: nil, RunFn: nil, } } func (c *Call) lock() { c.Parent.mutex.Lock() } func (c *Call) unlock() { c.Parent.mutex.Unlock() } // Return specifies the return arguments for the expectation. // // Mock.On("DoSomething").Return(errors.New("failed")) func (c *Call) Return(returnArguments ...interface{}) *Call { c.lock() defer c.unlock() c.ReturnArguments = returnArguments return c } // Once indicates that that the mock should only return the value once. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() func (c *Call) Once() *Call { return c.Times(1) } // Twice indicates that that the mock should only return the value twice. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() func (c *Call) Twice() *Call { return c.Times(2) } // Times indicates that that the mock should only return the indicated number // of times. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) func (c *Call) Times(i int) *Call { c.lock() defer c.unlock() c.Repeatability = i return c } // WaitUntil sets the channel that will block the mock's return until its closed // or a message is received. // // Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) func (c *Call) WaitUntil(w <-chan time.Time) *Call { c.lock() defer c.unlock() c.WaitFor = w return c } // After sets how long to block until the call returns // // Mock.On("MyMethod", arg1, arg2).After(time.Second) func (c *Call) After(d time.Duration) *Call { return c.WaitUntil(time.After(d)) } // Run sets a handler to be called before returning. It can be used when // mocking a method such as unmarshalers that takes a pointer to a struct and // sets properties in such struct // // Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(func(args Arguments) { // arg := args.Get(0).(*map[string]interface{}) // arg["foo"] = "bar" // }) func (c *Call) Run(fn func(Arguments)) *Call { c.lock() defer c.unlock() c.RunFn = fn return c } // On chains a new expectation description onto the mocked interface. This // allows syntax like. // // Mock. // On("MyMethod", 1).Return(nil). // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } // Mock is the workhorse used to track activity on another object. // For an example of its usage, refer to the "Example Usage" section at the top // of this document. type Mock struct { // Represents the calls that are expected of // an object. ExpectedCalls []*Call // Holds the calls that were made to this mocked object. Calls []Call // TestData holds any data that might be useful for testing. Testify ignores // this data completely allowing you to do whatever you like with it. testData objx.Map mutex sync.Mutex } // TestData holds any data that might be useful for testing. Testify ignores // this data completely allowing you to do whatever you like with it. func (m *Mock) TestData() objx.Map { if m.testData == nil { m.testData = make(objx.Map) } return m.testData } /* Setting expectations */ // On starts a description of an expectation of the specified method // being called. // // Mock.On("MyMethod", arg1, arg2) func (m *Mock) On(methodName string, arguments ...interface{}) *Call { for _, arg := range arguments { if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) } } m.mutex.Lock() defer m.mutex.Unlock() c := newCall(m, methodName, arguments...) m.ExpectedCalls = append(m.ExpectedCalls, c) return c } // /* // Recording and responding to activity // */ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { m.mutex.Lock() defer m.mutex.Unlock() for i, call := range m.ExpectedCalls { if call.Method == method && call.Repeatability > -1 { _, diffCount := call.Arguments.Diff(arguments) if diffCount == 0 { return i, call } } } return -1, nil } func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) { diffCount := 0 var closestCall *Call for _, call := range m.expectedCalls() { if call.Method == method { _, tempDiffCount := call.Arguments.Diff(arguments) if tempDiffCount < diffCount || diffCount == 0 { diffCount = tempDiffCount closestCall = call } } } if closestCall == nil { return false, nil } return true, closestCall } func callString(method string, arguments Arguments, includeArgumentValues bool) string { var argValsString string if includeArgumentValues { var argVals []string for argIndex, arg := range arguments { argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg)) } argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t")) } return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString) } // Called tells the mock object that a method has been called, and gets an array // of arguments to return. Panics if the call is unexpected (i.e. not preceded by // appropriate .On .Return() calls) // If Call.WaitFor is set, blocks until the channel is closed or receives a message. func (m *Mock) Called(arguments ...interface{}) Arguments { // get the calling function's name pc, _, _, ok := runtime.Caller(1) if !ok { panic("Couldn't get the caller information") } functionPath := runtime.FuncForPC(pc).Name() //Next four lines are required to use GCCGO function naming conventions. //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock //uses inteface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree //With GCCGO we need to remove interface information starting from pN
. re := regexp.MustCompile("\\.pN\\d+_") if re.MatchString(functionPath) { functionPath = re.Split(functionPath, -1)[0] } parts := strings.Split(functionPath, ".") functionName := parts[len(parts)-1] found, call := m.findExpectedCall(functionName, arguments...) if found < 0 { // we have to fail here - because we don't know what to do // as the return arguments. This is because: // // a) this is a totally unexpected call to this method, // b) the arguments are not what was expected, or // c) the developer has forgotten to add an accompanying On...Return pair. closestFound, closestCall := m.findClosestCall(functionName, arguments...) if closestFound { panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n", callString(functionName, arguments, true), callString(functionName, closestCall.Arguments, true))) } else { panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", functionName, functionName, callString(functionName, arguments, true), assert.CallerInfo())) } } else { m.mutex.Lock() switch { case call.Repeatability == 1: call.Repeatability = -1 call.totalCalls++ case call.Repeatability > 1: call.Repeatability-- call.totalCalls++ case call.Repeatability == 0: call.totalCalls++ } m.mutex.Unlock() } // add the call m.mutex.Lock() m.Calls = append(m.Calls, *newCall(m, functionName, arguments...)) m.mutex.Unlock() // block if specified if call.WaitFor != nil { <-call.WaitFor } if call.RunFn != nil { call.RunFn(arguments) } return call.ReturnArguments } /* Assertions */ // AssertExpectationsForObjects asserts that everything specified with On and Return // of the specified objects was in fact called as expected. // // Calls may have occurred in any order. func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { var success = true for _, obj := range testObjects { mockObj := obj.(Mock) success = success && mockObj.AssertExpectations(t) } return success } // AssertExpectations asserts that everything specified with On and Return was // in fact called as expected. Calls may have occurred in any order. func (m *Mock) AssertExpectations(t TestingT) bool { var somethingMissing bool var failedExpectations int // iterate through each expectation expectedCalls := m.expectedCalls() for _, expectedCall := range expectedCalls { if !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 { somethingMissing = true failedExpectations++ t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } else { m.mutex.Lock() if expectedCall.Repeatability > 0 { somethingMissing = true failedExpectations++ } else { t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } m.mutex.Unlock() } } if somethingMissing { t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo()) } return !somethingMissing } // AssertNumberOfCalls asserts that the method was called expectedCalls times. func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { var actualCalls int for _, call := range m.calls() { if call.Method == methodName { actualCalls++ } } return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) } // AssertCalled asserts that the method was called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false } return true } // AssertNotCalled asserts that the method was not called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false } return true } func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { for _, call := range m.calls() { if call.Method == methodName { _, differences := Arguments(expected).Diff(call.Arguments) if differences == 0 { // found the expected call return true } } } // we didn't find the expected call return false } func (m *Mock) expectedCalls() []*Call { m.mutex.Lock() defer m.mutex.Unlock() return append([]*Call{}, m.ExpectedCalls...) } func (m *Mock) calls() []Call { m.mutex.Lock() defer m.mutex.Unlock() return append([]Call{}, m.Calls...) } /* Arguments */ // Arguments holds an array of method arguments or return values. type Arguments []interface{} const ( // Anything is used in Diff and Assert when the argument being tested // shouldn't be taken into consideration. Anything string = "mock.Anything" ) // AnythingOfTypeArgument is a string that contains the type of an argument // for use when type checking. Used in Diff and Assert. type AnythingOfTypeArgument string // AnythingOfType returns an AnythingOfTypeArgument object containing the // name of the type to check for. Used in Diff and Assert. // // For example: // Assert(t, AnythingOfType("string"), AnythingOfType("int")) func AnythingOfType(t string) AnythingOfTypeArgument { return AnythingOfTypeArgument(t) } // argumentMatcher performs custom argument matching, returning whether or // not the argument is matched by the expectation fixture function. type argumentMatcher struct { // fn is a function which accepts one argument, and returns a bool. fn reflect.Value } func (f argumentMatcher) Matches(argument interface{}) bool { expectType := f.fn.Type().In(0) if reflect.TypeOf(argument).AssignableTo(expectType) { result := f.fn.Call([]reflect.Value{reflect.ValueOf(argument)}) return result[0].Bool() } return false } func (f argumentMatcher) String() string { return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name()) } // MatchedBy can be used to match a mock call based on only certain properties // from a complex struct or some calculation. It takes a function that will be // evaluated with the called argument and will return true when there's a match // and false otherwise. // // Example: // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) // // |fn|, must be a function accepting a single argument (of the expected type) // which returns a bool. If |fn| doesn't match the required signature, // MathedBy() panics. func MatchedBy(fn interface{}) argumentMatcher { fnType := reflect.TypeOf(fn) if fnType.Kind() != reflect.Func { panic(fmt.Sprintf("assert: arguments: %s is not a func", fn)) } if fnType.NumIn() != 1 { panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) } if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) } return argumentMatcher{fn: reflect.ValueOf(fn)} } // Get Returns the argument at the specified index. func (args Arguments) Get(index int) interface{} { if index+1 > len(args) { panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args))) } return args[index] } // Is gets whether the objects match the arguments specified. func (args Arguments) Is(objects ...interface{}) bool { for i, obj := range args { if obj != objects[i] { return false } } return true } // Diff gets a string describing the differences between the arguments // and the specified objects. // // Returns the diff string and number of differences found. func (args Arguments) Diff(objects []interface{}) (string, int) { var output = "\n" var differences int var maxArgCount = len(args) if len(objects) > maxArgCount { maxArgCount = len(objects) } for i := 0; i < maxArgCount; i++ { var actual, expected interface{} if len(objects) <= i { actual = "(Missing)" } else { actual = objects[i] } if len(args) <= i { expected = "(Missing)" } else { expected = args[i] } if matcher, ok := expected.(argumentMatcher); ok { if matcher.Matches(actual) { output = fmt.Sprintf("%s\t%d: \u2705 %s matched by %s\n", output, i, actual, matcher) } else { differences++ output = fmt.Sprintf("%s\t%d: \u2705 %s not matched by %s\n", output, i, actual, matcher) } } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { // type checking if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { // not match differences++ output = fmt.Sprintf("%s\t%d: \u274C type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actual) } } else { // normal checking if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { // match output = fmt.Sprintf("%s\t%d: \u2705 %s == %s\n", output, i, actual, expected) } else { // not match differences++ output = fmt.Sprintf("%s\t%d: \u274C %s != %s\n", output, i, actual, expected) } } } if differences == 0 { return "No differences.", differences } return output, differences } // Assert compares the arguments with the specified objects and fails if // they do not exactly match. func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { // get the differences diff, diffCount := args.Diff(objects) if diffCount == 0 { return true } // there are differences... report them... t.Logf(diff) t.Errorf("%sArguments do not match.", assert.CallerInfo()) return false } // String gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. // // If no index is provided, String() returns a complete string representation // of the arguments. func (args Arguments) String(indexOrNil ...int) string { if len(indexOrNil) == 0 { // normal String() method - return a string representation of the args var argsStr []string for _, arg := range args { argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg))) } return strings.Join(argsStr, ",") } else if len(indexOrNil) == 1 { // Index has been specified - get the argument at that index var index = indexOrNil[0] var s string var ok bool if s, ok = args.Get(index).(string); !ok { panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index))) } return s } panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil))) } // Int gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Int(index int) int { var s int var ok bool if s, ok = args.Get(index).(int); !ok { panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s } // Error gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Error(index int) error { obj := args.Get(index) var s error var ok bool if obj == nil { return nil } if s, ok = obj.(error); !ok { panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s } // Bool gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Bool(index int) bool { var s bool var ok bool if s, ok = args.Get(index).(bool); !ok { panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s }