@cli99 · Created September 18, 2024 00:14
torch.compile
import timeit
import torch
@torch.compile() # 0.103 seconds
# @torch.compile(fullgraph=True) # 0.105 seconds
# @torch.compile(fullgraph=False) # 0.102 seconds
# @torch.compile(options={"triton.cudagraphs": False}, fullgraph=True) # 0.104 seconds
# @torch.compile(
#     backend="cudagraphs", fullgraph=True
# )  # 0.155 seconds with torch.compiler.cudagraph_mark_step_begin()
# @torch.compile(backend="onnxrt", fullgraph=True) # 0.012 seconds
def silu_and_mul(x1: torch.Tensor, x2: torch.Tensor, cuda_sync=True) -> torch.Tensor:
    ret = torch.nn.functional.silu(x1) * x2
    if cuda_sync:
        torch.cuda.synchronize()
    return ret
# print(torch._dynamo.list_backends())
# ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']
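# To inspect the fused code Inductor generates for this op (an addition, not
# part of the original gist), one option is to dump the generated kernels:
#   TORCH_LOGS="output_code" python this_script.py
# or to check for graph breaks with torch._dynamo.explain(silu_and_mul)(x1, x2)
# once the inputs below exist.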
bsq = 128
dim = 8192
x1 = torch.randn(bsq, dim, device="cuda")
x2 = torch.randn(bsq, dim, device="cuda")
# Each tensor is 128 * 8192 * 4 bytes = 4 MiB. The kernel reads x1 and x2 and
# writes ret, ~12.6 MB of memory traffic per call; at an assumed 3.35 TB/s of
# peak bandwidth that is ~3.8 us per call, so ~3.8 ms for 1000 calls. The
# measured ~103 us/call is likely dominated by the per-call host sync and
# Python/dispatch overhead rather than the kernel itself.
# warmup: run a few iterations so compilation happens outside the timed region
for _ in range(3):
    out = silu_and_mul(x1, x2)
with torch.no_grad():
    # torch.compiler.cudagraph_mark_step_begin()
    t = timeit.timeit(lambda: silu_and_mul(x1, x2), number=1000)
print(f"Execution time: {t:.3f} seconds")