import timeit

import torch


@torch.compile()  # 0.103 seconds
# @torch.compile(fullgraph=True)  # 0.105 seconds
# @torch.compile(fullgraph=False)  # 0.102 seconds
# @torch.compile(options={"triton.cudagraphs": False}, fullgraph=True)  # 0.104 seconds
# @torch.compile(
#     backend="cudagraphs", fullgraph=True
# )  # 0.155 seconds with torch.compiler.cudagraph_mark_step_begin()
# @torch.compile(backend="onnxrt", fullgraph=True)  # 0.012 seconds
def silu_and_mul(x1: torch.Tensor, x2: torch.Tensor, cuda_sync=True) -> torch.Tensor:
    # SwiGLU-style gating: silu(x1) * x2.
    ret = torch.nn.functional.silu(x1) * x2
    if cuda_sync:
        # Block the host until the kernel finishes, so timeit measures GPU
        # execution time rather than just kernel-launch overhead.
        torch.cuda.synchronize()
    return ret


# print(torch._dynamo.list_backends())
# ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

bsq = 128
dim = 8192
x1 = torch.randn(bsq, dim, device="cuda")
x2 = torch.randn(bsq, dim, device="cuda")
# Each tensor: 128 * 8192 elements * 4 bytes = 4 MiB.
# Per call the kernel reads x1 and x2 and writes the output: 3 * 4 MiB ≈ 12.6 MB,
# so the bandwidth-bound estimate for 1000 calls at 3.35 TB/s is
# 1000 * 12.6 MB / 3.35 TB/s ≈ 3.8 ms. The measured times above are much larger
# because they also include per-call Python, launch, and synchronize overhead.

# Warmup: the first calls trigger compilation.
for _ in range(3):
    out = silu_and_mul(x1, x2)

with torch.no_grad():
    # torch.compiler.cudagraph_mark_step_begin()
    t = timeit.timeit(lambda: silu_and_mul(x1, x2), number=1000)

print(f"Execution time: {t:.3f} seconds")
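
# A sketch of an alternative timing approach using CUDA events (an addition for
# illustration; it assumes the silu_and_mul function and inputs defined above).
# Events are timestamped on the GPU stream itself, so the loop needs only one
# host synchronize at the end instead of one per call, which isolates kernel
# time from the per-call sync overhead included in the timeit numbers.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    for _ in range(1000):
        silu_and_mul(x1, x2, cuda_sync=False)
    end.record()
torch.cuda.synchronize()  # ensure both events have completed before reading them
# elapsed_time returns milliseconds between the two recorded events.
print(f"CUDA event time: {start.elapsed_time(end) / 1000:.3f} ms/iter")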