Skip to content

Instantly share code, notes, and snippets.

@akshatvishu
Created September 25, 2023 18:37
Show Gist options
  • Save akshatvishu/6d859b5f8c55e6a24995261e9abb5785 to your computer and use it in GitHub Desktop.
Save akshatvishu/6d859b5f8c55e6a24995261e9abb5785 to your computer and use it in GitHub Desktop.
An attempt to mimic `paddle.scale`.
def custom_scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
    """Mimic ``paddle.scale``: compute ``x * scale + bias`` or ``(x + bias) * scale``.

    Args:
        x: Input ``paddle.Tensor``.
        scale: Multiplier. A non-tensor value is converted with
            ``paddle.to_tensor`` to the (possibly promoted, see below) dtype of
            ``x``; for integer inputs a fractional value is therefore converted
            to that integer dtype.
        bias: Additive term, converted the same way as ``scale``.
        bias_after_scale: If True compute ``x * scale + bias``, otherwise
            ``(x + bias) * scale``.
        act: Optional activation name applied to the result. Must be one of
            ``'tanh'``, ``'softmax'``, ``'sigmoid'``, ``'relu'``, and is only
            accepted when ``x`` is float32 or float64.
        name: Unused; kept for signature compatibility with ``paddle.scale``.

    Returns:
        A tensor cast back to the original dtype of ``x``, with the activation
        applied if ``act`` was given.

    Raises:
        ValueError: If ``act`` is given and ``x`` is not float32/float64, or if
            ``act`` is not a supported activation (or cannot be found in
            ``paddle.nn.functional`` / ``paddle``).
    """
    original_dtype = x.dtype

    # Resolve and validate the activation before doing any work, so a bad
    # ``act`` fails loudly instead of silently returning un-activated output.
    act_fn = None
    if act is not None:
        if original_dtype not in [paddle.float32, paddle.float64]:
            raise ValueError(
                f"For CPU, only float32 and float64 are allowed for act = {act!r}"
            )
        # These are the activations listed in the paddle.scale docs:
        # https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/scale_en.html#scale
        allowed_acts = ['tanh', 'softmax', 'sigmoid', 'relu']
        if act in allowed_acts:
            # Prefer paddle.nn.functional, fall back to the top-level namespace.
            for namespace in (paddle.nn.functional, paddle):
                act_fn = getattr(namespace, act, None)
                if act_fn is not None:
                    break
        if act_fn is None:
            raise ValueError(
                f"Unsupported activation function {act!r}. "
                f"Supported functions are: {allowed_acts}"
            )

    # Per the original author's note, paddle's add/multiply do not support
    # int8/uint8 (at least on CPU); int16 is promoted as well. The math runs in
    # int32 and the result is cast back to ``original_dtype`` below.
    if original_dtype in [paddle.int16, paddle.int8, paddle.uint8]:
        x = paddle.cast(x, dtype=paddle.int32)

    # TODO: a Tensor ``scale``/``bias`` is used as-is; its dtype must already
    # match ``x`` (after any int32 promotion) or add/multiply may fail.
    if not isinstance(scale, paddle.Tensor):
        scale = paddle.to_tensor(scale, dtype=x.dtype)
    if not isinstance(bias, paddle.Tensor):
        bias = paddle.to_tensor(bias, dtype=x.dtype)

    if bias_after_scale:
        out = paddle.add(paddle.multiply(x, scale), bias)
    else:
        out = paddle.multiply(paddle.add(x, bias), scale)

    # Cast back to the caller's dtype (undoes the int32 promotion, if any).
    out = paddle.cast(out, dtype=original_dtype)

    if act_fn is not None:
        out = act_fn(out)
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment