Created
September 25, 2023 18:37
-
-
Save akshatvishu/6d859b5f8c55e6a24995261e9abb5785 to your computer and use it in GitHub Desktop.
An attempt to mimic paddle.scale
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def custom_scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
    """Mimic ``paddle.scale``: ``x * scale + bias`` or ``(x + bias) * scale``.

    Args:
        x: Input ``paddle.Tensor``.
        scale: Multiplier, either a Python number or a ``paddle.Tensor``.
            It is converted (number) or cast (Tensor) to the dtype used for
            the arithmetic, because ``paddle.multiply`` needs matching dtypes.
        bias: Additive term, a Python number or a ``paddle.Tensor``.
            It is converted to the arithmetic dtype the same way as ``scale``.
        bias_after_scale: If True compute ``x * scale + bias``, otherwise
            compute ``(x + bias) * scale``.
        act: Optional activation name, one of ``'tanh'``, ``'softmax'``,
            ``'sigmoid'``, ``'relu'``, applied to the result.
        name: Unused; accepted only for signature compatibility with
            ``paddle.scale``.

    Returns:
        A ``paddle.Tensor`` cast back to the original dtype of ``x``
        (an activation, if any, is applied after that cast).

    Raises:
        ValueError: If ``act`` is given while ``x`` is not float32/float64,
            or if ``act`` is not one of the supported activation names.
    """
    # These are the activations mentioned in the docs for paddle.scale.
    # Reference: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/scale_en.html#scale
    allowed_acts = ("tanh", "softmax", "sigmoid", "relu")
    original_dtype = x.dtype

    # Validate `act` up front so a bad call fails loudly instead of
    # returning an un-activated tensor.
    if act is not None:
        if original_dtype not in (paddle.float32, paddle.float64):
            raise ValueError(
                f"For CPU, only float32 and float64 are allowed for act = {act!r}; "
                f"got dtype {original_dtype}"
            )
        if act not in allowed_acts:
            raise ValueError(
                f"Unsupported activation function {act!r}. "
                f"Supported functions are: {list(allowed_acts)}"
            )

    # paddle.add / paddle.multiply don't support int8, int16 or uint8
    # (at least on CPU), so do the arithmetic in int32. The result is cast
    # back to the original dtype below.
    if original_dtype in (paddle.int16, paddle.int8, paddle.uint8):
        x = paddle.cast(x, dtype=paddle.int32)

    def _to_dtype(value, dtype):
        # `x`, `scale` and `bias` must share a dtype for add/multiply, so
        # numbers become tensors of `dtype` and mismatched tensors are cast.
        if isinstance(value, paddle.Tensor):
            return value if value.dtype == dtype else paddle.cast(value, dtype=dtype)
        return paddle.to_tensor(value, dtype=dtype)

    scale = _to_dtype(scale, x.dtype)
    bias = _to_dtype(bias, x.dtype)

    if bias_after_scale:
        out = paddle.add(paddle.multiply(x, scale), bias)
    else:
        out = paddle.multiply(paddle.add(x, bias), scale)

    # Cast back to the caller's dtype (undoes the int32 widening above).
    out = paddle.cast(out, dtype=original_dtype)

    if act is not None:
        # Prefer the functional API, and fall back to a top-level paddle op.
        act_fn = getattr(paddle.nn.functional, act, None) or getattr(paddle, act, None)
        if act_fn is None:
            raise ValueError(
                f"Unsupported activation function {act!r}. "
                f"Supported functions are: {list(allowed_acts)}"
            )
        out = act_fn(out)

    return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment