@akshatvishu
Created September 25, 2023 18:37
my_paddle_scale.py
import paddle


def custom_scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
    """Emulate paddle.scale: out = scale * x + bias (or scale * (x + bias)),
    with an optional activation and a workaround for small integer dtypes."""
    original_dtype = x.dtype

    # Cast int16/int8/uint8 inputs to int32 for the arithmetic, because
    # paddle.add / paddle.multiply don't support those dtypes, at least on CPU.
    if original_dtype in [paddle.int16, paddle.int8, paddle.uint8]:
        x = paddle.cast(x, dtype=paddle.int32)

    # TODO: may need extra logic when `scale` is passed as a tensor,
    # because the dtypes of `x` and `scale` need to match!
    if not isinstance(scale, paddle.Tensor):
        scale = paddle.to_tensor(scale, dtype=x.dtype)

    if not isinstance(bias, paddle.Tensor):
        bias = paddle.to_tensor(bias, dtype=x.dtype)

    if bias_after_scale:
        out = paddle.add(paddle.multiply(x, scale), bias)
    else:
        out = paddle.multiply(paddle.add(x, bias), scale)

    # Truncating to the nearest integer before casting back, e.g.
    #     if "int" in str(original_dtype):
    #         out = paddle.trunc(out)
    # turns out not to work in Paddle.

    # Cast back to the original dtype.
    out = paddle.cast(out, dtype=original_dtype)

    # Apply the activation function if one was provided and is allowed.
    if act is not None:
        if original_dtype not in [paddle.float32, paddle.float64]:
            print("Error: For CPU, only float32 and float64 are allowed for act =", act)
        else:
            # These are the activations mentioned in the docs for paddle.scale.
            # Reference: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/scale_en.html#scale
            allowed_acts = ['tanh', 'softmax', 'sigmoid', 'relu']
            if act in allowed_acts:
                if hasattr(paddle.nn.functional, act):
                    out = getattr(paddle.nn.functional, act)(out)
                elif hasattr(paddle, act):
                    out = getattr(paddle, act)(out)
                else:
                    print("Error: Unsupported activation function. Supported functions are:", allowed_acts)
            else:
                print("Error: Unsupported activation function. Supported functions are:", allowed_acts)

    return out
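
A minimal usage sketch (the input tensors and expected values below are illustrative assumptions, not part of the gist):

# Assuming custom_scale from above is in scope.
import paddle

# Float path: scale first, then bias, then an allowed activation.
x = paddle.to_tensor([1.0, 2.0, 3.0], dtype=paddle.float32)
y = custom_scale(x, scale=2.0, bias=1.0, bias_after_scale=True, act='tanh')
# Equivalent to paddle.tanh(2.0 * x + 1.0).

# Integer path: the int8 input is promoted to int32 for the arithmetic,
# then cast back to int8 on the way out.
xi = paddle.to_tensor([1, 2, 3], dtype=paddle.int8)
yi = custom_scale(xi, scale=3, bias=1)  # -> int8 tensor [4, 7, 10]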