Skip to content

Instantly share code, notes, and snippets.

@Cyberax
Created July 26, 2021 04:52
Show Gist options
  • Select an option

  • Save Cyberax/eb42d249d022c55ce9dc6572309200ce to your computer and use it in GitHub Desktop.

Select an option

Save Cyberax/eb42d249d022c55ce9dc6572309200ce to your computer and use it in GitHub Desktop.

Revisions

  1. Cyberax created this gist Jul 26, 2021.
    158 changes: 158 additions & 0 deletions aws_mock.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,158 @@
    package utils

    import (
    "context"
    "github.com/aws/aws-sdk-go-v2/aws"
    awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
    "github.com/aws/smithy-go/middleware"
    "reflect"
    )

    type AwsMockHandler struct {
    handlers []reflect.Value
    functors []reflect.Value
    }

    // NewAwsMockHandler - Create an AWS mocker to use with the AWS services, it returns an instrumented
    // aws.Config that can be used to create AWS services.
    // You can add as many individual request handlers as you need, as long as handlers
    // correspond to the func(context.Context, <arg>)(<res>, error) format.
    // E.g.:
    // func(context.Context, *ec2.TerminateInstancesInput)(*ec2.TerminateInstancesOutput, error)
    //
    // You can also use a struct as the handler, in this case the AwsMockHandler will try
    // to search for a method with a conforming signature.
    func NewAwsMockHandler() *AwsMockHandler {
    return &AwsMockHandler{}
    }

    type retargetingHandler struct {
    parent *AwsMockHandler
    }

    func (f *retargetingHandler) ID() string {
    return "ShortCircuitRequest"
    }

    type initialRequestKey struct{}

    func (f *retargetingHandler) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput,
    next middleware.DeserializeHandler) (out middleware.DeserializeOutput, metadata middleware.Metadata, err error) {

    req := ctx.Value(&initialRequestKey{})
    out.Result, err = f.parent.invokeMethod(ctx, req)
    return
    }

    type saveRequestMiddleware struct {
    }

    func (s saveRequestMiddleware) ID() string {
    return "OriginalRequestSaver"
    }

    func (s saveRequestMiddleware) HandleInitialize(ctx context.Context, in middleware.InitializeInput,
    next middleware.InitializeHandler) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) {

    return next.HandleInitialize(context.WithValue(ctx, &initialRequestKey{}, in.Parameters), in)
    }

    func (a *AwsMockHandler) AwsConfig() aws.Config {
    cfg := aws.NewConfig()
    cfg.Region = "us-mars-1"
    cfg.APIOptions = []func(*middleware.Stack) error{
    func(stack *middleware.Stack) error {
    // We leave the serialization middleware intact in the vain hope that
    // AWS re-adds validation to serialization.
    //stack.Initialize.Clear()
    //stack.Serialize.Clear()

    // Make sure to save the initial non-serialized request
    _ = stack.Initialize.Add(&saveRequestMiddleware{}, middleware.Before)

    // Clear all the other middleware
    stack.Build.Clear()
    stack.Finalize.Clear()
    stack.Deserialize.Clear()

    // And replace the last one with our special middleware that dispatches
    // the request to our handlers
    _ = stack.Deserialize.Add(&retargetingHandler{parent: a}, middleware.Before)
    return nil
    },
    }

    return *cfg
    }


    func (a *AwsMockHandler) AddHandler(handlerObject interface {}) {
    handler := reflect.ValueOf(handlerObject)
    tp := handler.Type()

    if handler.Kind() == reflect.Func {
    PanicIfF(tp.NumOut() != 2 || tp.NumIn() != 2,
    "handler must have signature of func(context.Context, <arg>)(<res>, error)")
    a.functors = append(a.functors, handler)
    } else {
    PanicIfF(tp.NumMethod() == 0, "the handler must have invokable methods")
    a.handlers = append(a.handlers, handler)
    }
    }

    func (a *AwsMockHandler) invokeMethod(ctx context.Context,
    params interface{}) (interface{}, error) {

    for _, h := range a.handlers {
    for i := 0; i < h.NumMethod(); i++ {
    method := h.Method(i)

    matched, res, err := tryInvoke(ctx, params, method)
    if matched {
    return res, err
    }
    }
    }

    for _, f := range a.functors {
    matched, res, err := tryInvoke(ctx, params, f)
    if matched {
    return res, err
    }
    }

    panic("could not find a handler for operation: " + awsmiddleware.GetOperationName(ctx))
    }

    func tryInvoke(ctx context.Context, params interface{}, method reflect.Value) (
    bool, interface{}, error) {

    paramType := reflect.TypeOf(params)
    errorType := reflect.TypeOf((*error)(nil)).Elem()
    contextType := reflect.TypeOf((*context.Context)(nil)).Elem()

    methodDesc := method.Type()
    if methodDesc.NumIn() != 2 || methodDesc.NumOut() != 2 {
    return false, nil, nil
    }

    if !contextType.ConvertibleTo(methodDesc.In(0)) {
    return false, nil, nil
    }
    if !paramType.ConvertibleTo(methodDesc.In(1)) {
    return false, nil, nil
    }
    if !methodDesc.Out(1).ConvertibleTo(errorType) {
    return false, nil, nil
    }

    // It's our target!
    res := method.Call([]reflect.Value{reflect.ValueOf(ctx),
    reflect.ValueOf(params)})

    if !res[1].IsNil() {
    return true, nil, res[1].Interface().(error)
    }

    return true, res[0].Interface(), nil
    }
    104 changes: 104 additions & 0 deletions aws_mock_test.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,104 @@
    package utils

    import (
    "context"
    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/ec2"
    "github.com/aws/smithy-go"
    "github.com/stretchr/testify/assert"
    "testing"
    )

    type tester struct {
    }

    //noinspection GoUnusedParameter
    func (t *tester) TerminateInstances(ctx context.Context,
    input *ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) {
    return nil, smithy.NewErrParamRequired("something")
    }

    //noinspection GoUnusedParameter
    func (t *tester) AlmostRunDescribe1(ctx context.Context,
    input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, string) {
    return nil, ""
    }

    //noinspection GoUnusedParameter
    func (t *tester) AlmostRunDescribe2(input *ec2.DescribeInstancesInput, _ string) (
    *ec2.DescribeInstancesOutput, error) {
    return nil, nil
    }

    //noinspection GoUnusedParameter
    func (t *tester) AlmostRunDescribe3(ctx context.Context,
    input *ec2.DescribeAccountAttributesInput) (*ec2.DescribeInstancesOutput, error) {
    return nil, nil
    }

    //noinspection GoUnusedParameter
    func (t *tester) AlmostRunDescribe4(input *ec2.DescribeInstancesInput) (
    *ec2.DescribeInstancesOutput, error) {
    return nil, nil
    }

    //noinspection GoUnusedParameter
    func (t *tester) AlmostRunDescribe5(ctx context.Context,
    input *ec2.DescribeInstancesInput) error {
    return nil
    }

    func TestMockNotFound(t *testing.T) {
    am := AwsMockHandler{}
    am.AddHandler(&tester{})

    assert.Panics(t, func() {
    ec := ec2.NewFromConfig(am.AwsConfig())
    _, _ = ec.DeleteKeyPair(context.Background(), &ec2.DeleteKeyPairInput{
    KeyName: aws.String("something"),
    })
    }, "could not find a handler for operation: DeleteKeyPair")
    }

    func TestAwsMock(t *testing.T) {
    am := NewAwsMockHandler()
    am.AddHandler(&tester{})
    am.AddHandler(func(ctx context.Context, arg *ec2.DescribeInstancesInput) (
    *ec2.DescribeInstancesOutput, error) {
    return &ec2.DescribeInstancesOutput{NextToken: arg.NextToken}, nil
    })
    am.AddHandler(func(ctx context.Context, arg *ec2.TerminateInstancesInput) (
    *ec2.DescribeInstancesOutput, error) {
    return nil, nil
    })

    ec := ec2.NewFromConfig(am.AwsConfig())

    response, e := ec.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{
    NextToken: aws.String("hello, token"),
    })
    assert.NoError(t, e)
    assert.Equal(t, "hello, token", *response.NextToken)

    // Check the tester methods
    _, err := ec.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{})
    assert.Error(t, err, "something")
    }

    func ExampleNewAwsMockHandler() {
    am := NewAwsMockHandler()
    am.AddHandler(func(ctx context.Context, arg *ec2.TerminateInstancesInput) (
    *ec2.TerminateInstancesOutput, error) {

    if arg.InstanceIds[0] != "i-123" {
    panic("BadInstanceId")
    }
    return &ec2.TerminateInstancesOutput{}, nil
    })

    ec := ec2.NewFromConfig(am.AwsConfig())

    _, _ = ec.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{
    InstanceIds: []string{"i-123"},
    })
    }