Skip to content

Instantly share code, notes, and snippets.

@mrdulin
Forked from feketegy/mockmetrics.go
Created May 25, 2020 12:45
Show Gist options
  • Save mrdulin/8d8e173b77668c99ad784fc24fa21f28 to your computer and use it in GitHub Desktop.
Save mrdulin/8d8e173b77668c99ad784fc24fa21f28 to your computer and use it in GitHub Desktop.

Revisions

  1. @feketegy feketegy created this gist Mar 9, 2020.
    233 changes: 233 additions & 0 deletions mockmetrics.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,233 @@
    // Package mockmetrics will gather metrics on a mocked function, such as number of calls, record arguments that the
    // mocked function is called with.
    //
    // Example:
    //
    // type NumberGetter interface {
    // GetNumber() int
    // AddNumber(num int)
    // }
    //
    // type MockedNumber struct {
    // Nr int
    // mockmetrics.Spy
    // }
    //
    // func (p *MockedNumber) GetNumber() int {
    // p.Called()
    // return p.Nr
    // }
    //
    // func (p *MockedNumber) AddNumber(num int) {
    // p.Called(num)
    // p.Nr += 1
    // }
    //
    // func main() {
    // mock := MockerNumber{
    // Nr: 15,
    // }
    //
    // myNr := mock.GetNumber()
    // fmt.Printf("\n%d", myNr)
    // mock.AddNumber(2)
    // myNewNr := mock.GetNumber()
    // fmt.Printf("\n%d", myNewNr)
    //
    // call1, err1 := mock.GetCall("GetNumber", 1)
    // call2, err2 := mock.GetCall("GetNumber", 2)
    // call3, err3 := mock.GetCall("AddNumber", 1)
    //
    // fmt.Printf("\n GetNumber Nr Calls: %d", mock.NrCalls("GetNumber"))
    // fmt.Printf("\n GetNumber Call 1 At: %+v -- err: %+v", call1.CalledAt(), err1)
    // fmt.Printf("\n GetNumber Call 2 At: %+v -- err: %+v", call2.CalledAt(), err2)
    // fmt.Printf("\n AddNumber Nr Calls: %d", mock.NrCalls("AddNumber"))
    // fmt.Printf("\n AddNumber Call 1 Called With: %+v -- At: %+v -- err: %+v", call3.CalledWith(), call3.CalledAt(), err3)
    //
    // }
    //
    package mockmetrics

    import (
    "errors"
    "regexp"
    "runtime"
    "strings"
    "sync"
    "time"
    )

    var (
    errNotFound = errors.New("Not Found.")
    )

    // call represents a single call to a mocked function.
    type call struct {
    spy *Spy
    at time.Time
    calledWith []interface{}
    }

    // callMetrics represents all the calls to a mocked function.
    type callMetrics struct {
    spy *Spy
    calls []call
    }

    // Spy is the composable structure that needs to be added to the mocked structs.
    //
    // Example:
    //
    // type MyMock struct {
    // mockmetrics.Spy
    // }
    //
    // This way MyMock will be composed with Spy struct's methods and gather info on MyMock.
    type Spy struct {
    calls map[string]*callMetrics
    mtx sync.Mutex
    }

    // Called will record info on the method called.
    //
    // Example:
    //
    // type MyMock struct {
    // mockmetrics.Spy
    // }
    //
    // func (p *MyMock) FuncToSatisfyInterface(arg1 int, arg2 string) (err error) {
    // p.Called(arg1, arg2)
    // }
    //
    // It's up to the implementer to invoke Called and supply the arguments.
    func (p *Spy) Called(args ...interface{}) {

    p.mtx.Lock()
    defer p.mtx.Unlock()

    pc, _, _, ok := runtime.Caller(1)
    if !ok {
    panic("Coudn't get the called func information.")
    }

    funcPath := runtime.FuncForPC(pc).Name()
    funcName := getFuncName(funcPath)

    if len(p.calls) == 0 {
    p.calls = make(map[string]*callMetrics)
    }

    if p.calls[funcName] == nil {
    p.calls[funcName] = &callMetrics{
    spy: p,
    }
    }

    p.calls[funcName].calls = append(p.calls[funcName].calls, call{
    spy: p,
    at: time.Now(),
    calledWith: args,
    })
    }

    // GetCall will get the info on the called funcName and index representing the call count.
    // The call count index starts from 1.
    //
    // Example:
    //
    // mock := MyMock{}
    // mock.DoSomething()
    // call, err := mock.GetCall("DoSomething", 1)
    //
    func (p *Spy) GetCall(funcName string, index int) (c call, err error) {

    p.mtx.Lock()
    defer p.mtx.Unlock()

    if index < 1 {
    return c, errNotFound
    }

    cm, ok := p.calls[funcName]
    if !ok {
    return c, errNotFound
    }

    if len(cm.calls) <= index-1 {
    return c, errNotFound
    }

    c = cm.calls[index-1]

    return
    }

    // NrCalls will return the total calls a mocked funcName received.
    //
    // Example:
    //
    // mock := MyMock{}
    // mock.DoSomething()
    // totalCalls := mock.NrCalls("DoSomething")
    //
    func (p *Spy) NrCalls(funcName string) int {
    p.mtx.Lock()
    defer p.mtx.Unlock()

    cm, ok := p.calls[funcName]
    if !ok {
    return 0
    }

    return len(cm.calls)
    }

    // CalledWith will return the argument values the function was called with.
    // It's up to the caller to cast each value to the appropiate type for further comparisons.
    //
    // Example:
    //
    // mock := MyMock{}
    // mock.DoSomething("abc", 123)
    // mock.DoSomething("def", 456)
    // call1, err := mock.GetCall("DoSomething", 1)
    // call2, err := mock.GetCall("DoSomething", 2)
    // args1 := call1.CalledWith() // args1 = ["abc", 123]
    // args2 := call2.CalledWith() // args2 = ["def", 456]
    //
    func (p *call) CalledWith() []interface{} {

    p.spy.mtx.Lock()
    defer p.spy.mtx.Unlock()

    return p.calledWith
    }

    // CalledAt will return the time the mock function was called.
    //
    // Example:
    //
    // mock := MyMock{}
    // mock.DoSomething()
    // call1, err := mock.GetCall("DoSomething", 1)
    // t := call1.CalledAt()
    //
    func (p *call) CalledAt() time.Time {

    p.spy.mtx.Lock()
    defer p.spy.mtx.Unlock()

    return p.at
    }

    // getFuncName will return the runtime function name based on the PC function path.
    func getFuncName(funcPath string) string {
    re := regexp.MustCompile("\\.pN\\d+_")
    if re.MatchString(funcPath) {
    funcPath = re.Split(funcPath, -1)[0]
    }

    parts := strings.Split(funcPath, ".")
    return parts[len(parts)-1]
    }