"""Print one small tensor on an XLA device as float32, bfloat16 and float8_e4m3fn."""

import os
import time

import torch
import torch_xla.core.xla_model as xm

# Number of columns in the generated test tensor.
N = 16


def main():
    """Build a (32, N) tensor on the XLA device and print it in several dtypes.

    The Neuron-related environment variables are set first, before the XLA
    device is requested. Each ``print`` shows the tensor together with its
    variable name (f-string ``=`` specifier).
    """
    # Disabled alternative kept from the original script:
    # os.environ["XLA_USE_BF16"] = "1"
    os.environ.update(
        {
            "NEURON_RT_STOCHASTIC_ROUNDING_EN": "1",
            "NEURON_CC_FLAGS": " --internal-hlo2tensorizer-options=--experimental-unsafe-fp8e4m3fn-as-fp8e4m3 --execute-repetition=1 ",
        }
    )

    device = xm.xla_device()

    # One row holding 0..N-1, viewed as 32 identical rows, moved to the
    # device and doubled there (so the values are 0, 2, ..., 2 * (N - 1)).
    row = torch.arange(N).reshape(1, N)
    tiled = row.expand(32, N)
    data = tiled.to(device=device) * 2
    print(f"{data=}")

    # Disabled alternative kept from the original script:
    # output = (data / float(N)).to(dtype=torch.float8_e4m3fn)

    # float32 reference: data / N, cast to float32, then scaled by 3.11111.
    scaled = data / float(N)
    output_fp32 = scaled.to(dtype=torch.float32) * 3.11111
    print(f"{output_fp32=}")

    # Both lower-precision tensors are cast from the same float32 reference.
    output_bf16 = output_fp32.to(dtype=torch.bfloat16)
    print(f"{output_bf16=}")

    output_fp8e4m3 = output_fp32.to(dtype=torch.float8_e4m3fn)
    print(f"{output_fp8e4m3=}")


if __name__ == "__main__":
    main()