Created
September 25, 2023 18:37
-
-
Save akshatvishu/6d859b5f8c55e6a24995261e9abb5785 to your computer and use it in GitHub Desktop.
Revisions
-
akshatvishu created this gist
Sep 25, 2023 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,50 @@ def custom_scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): original_dtype = x.dtype if original_dtype in [paddle.int16,paddle.int8, paddle.uint8]: x = paddle.cast(x, dtype=paddle.int32) # ToDO:May need to add logic when scale_input is tensor # because `x` dtype and `scale` dtype need to match! if not isinstance(scale, paddle.Tensor): scale = paddle.to_tensor(scale, dtype=x.dtype) if not isinstance(bias, paddle.Tensor): bias = paddle.to_tensor(bias, dtype=x.dtype) # Convert dtype to int32 for calculations if input dtype is int8 or uint8 # because add or multiply API don't support them at-least for device=CPU if bias_after_scale: out = paddle.add(paddle.multiply(x, scale), bias) else: out = paddle.multiply(paddle.add(x, bias), scale) # Truncate to the nearest integer for integer types before casting back # if "int" in str(original_dtype): # out = paddle.trunc(out) # Turns out this dosen't work, CurseYou Paddlefokls!! # Cast back to the original dtype out = paddle.cast(out, dtype=original_dtype) # Apply activation function if provided and allowed if act is not None: if original_dtype not in [paddle.float32, paddle.float64]: print("Error: For CPU, only float32 and float64 are allowed for act =", act) else: # These where the ones mentioned in docs for paddle.scale # Reference: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/scale_en.html#scale allowed_acts = ['tanh', 'softmax', 'sigmoid', 'relu'] if act in allowed_acts: if hasattr(paddle.nn.functional, act): out = getattr(paddle.nn.functional, act)(out) elif hasattr(paddle, act): out = getattr(paddle, act)(out) else: print("Error: Unsupported activation function. Supported functions are:", allowed_acts) else: print("Error: Unsupported activation function. Supported functions are:", allowed_acts) return out