torch.compile
import timeit

import torch


@torch.compile()  # 0.103 seconds
# @torch.compile(fullgraph=True)  # 0.105 seconds
# @torch.compile(fullgraph=False)  # 0.102 seconds
# @torch.compile(options={"triton.cudagraphs": False}, fullgraph=True)  # 0.104 seconds
# @torch.compile(
#     backend="cudagraphs", fullgraph=True
# )  # 0.155 seconds with torch.compiler.cudagraph_mark_step_begin()
# @torch.compile(backend="onnxrt", fullgraph=True)  # 0.012 seconds
def silu_and_mul(x1: torch.Tensor, x2: torch.Tensor, cuda_sync: bool = True) -> torch.Tensor:
    ret = torch.nn.functional.silu(x1) * x2
    if cuda_sync:
        torch.cuda.synchronize()
    return ret


# Available dynamo backends:
# print(torch._dynamo.list_backends())
# ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

bsq = 128
dim = 8192
x1 = torch.randn(bsq, dim, device="cuda")
x2 = torch.randn(bsq, dim, device="cuda")
# Each input tensor is 128 * 8192 * 4 bytes = 4 MiB.
# Memory traffic per call: read x1 + read x2 + write out = 3 * 4 MiB = 12 MiB.
# Memory-bound estimate for 1000 calls at ~3.35 TB/s (e.g. H100 SXM):
# 1000 * 12 MiB / 3.35 TB/s ≈ 3.8 ms, far below the measured ~0.1 s,
# so per-call launch and synchronize overhead dominates, not bandwidth.
# warmup (also triggers compilation)
for _ in range(3):
    out = silu_and_mul(x1, x2)

with torch.no_grad():
    # torch.compiler.cudagraph_mark_step_begin()
    t = timeit.timeit(lambda: silu_and_mul(x1, x2), number=1000)
    print(f"Execution time: {t:.3f} seconds")