Revisions
Chillee revised this gist
Feb 26, 2024. 1 changed file with 25 additions and 23 deletions.
@@ -11,35 +11,37 @@

D = 2048
E = 8

for D in [1024, 2048, 4096, 8192, 16384]:
    def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
        import time
        from triton.testing import do_bench
        for _ in range(warmup):
            f()
        if profile:
            with torch.profiler.profile() as prof:
                f()
            prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
        us_per_iter = do_bench(lambda: f())*1000
        print(f"{name}: {(1e6/us_per_iter) * 2 * D * D * 4 / 1e9} GB/s")
        return 0

    def cuda_indexing(W, score_idxs, x):
        return W[score_idxs] @ x

    def python_indexing(W, score_idxs, x):
        return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x

    W = torch.randn(E, D, D)
    x = torch.randn(D)
    score_idxs = torch.tensor([3, 5])

    compiled_cuda = torch.compile(cuda_indexing, dynamic=False)

    print(f"D={D}")
    bench(lambda: python_indexing(W, score_idxs, x), "python indexing")
    bench(lambda: cuda_indexing(W, score_idxs, x), "eager CUDA indexing")
    bench(lambda: compiled_cuda(W, score_idxs, x), "compiled CUDA indexing")
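As a quick sanity check of the GB/s figure printed by bench() above (a rough sketch with an illustrative timing, not part of the gist): each call reads the two selected D x D fp32 expert matrices, so the traffic per call is 2 * D * D * 4 bytes.

    # Back-of-the-envelope check of the bandwidth formula in bench() (illustrative numbers only).
    D = 2048
    bytes_per_call = 2 * D * D * 4      # two D x D fp32 expert matrices, ~33.6 MB at D = 2048
    us_per_iter = 30.0                  # hypothetical do_bench result, in microseconds
    print(f"{(1e6 / us_per_iter) * bytes_per_call / 1e9:.1f} GB/s")   # ~1118 GB/s for these numbers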
Chillee revised this gist
Feb 26, 2024. 1 changed file with 45 additions and 0 deletions.
@@ -0,0 +1,45 @@

import torch
from torch import nn
import torch.nn.functional as F
import torch.autograd as autograd
torch.set_default_device('cuda')

import torch._inductor.config
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.assert_indirect_indexing = False

D = 2048
E = 8

def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
    import time
    from triton.testing import do_bench
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    us_per_iter = do_bench(lambda: f())*1000
    print(f"{name}: {(1e6/us_per_iter) * 2 * D * D * 4 / 1e9} GB/s")
    return 0

def cuda_indexing(W, score_idxs, x):
    return W[score_idxs] @ x

def python_indexing(W, score_idxs, x):
    return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x

W = torch.randn(E, D, D)
x = torch.randn(D)
score_idxs = torch.tensor([3, 5])

compiled_cuda = torch.compile(cuda_indexing)

bench(lambda: python_indexing(W, score_idxs, x), "python indexing")
bench(lambda: cuda_indexing(W, score_idxs, x), "eager CUDA indexing")
bench(lambda: compiled_cuda(W, score_idxs, x), "compiled CUDA indexing")
Chillee revised this gist
Jul 7, 2023. 2 changed files with 2 additions and 2 deletions.
The same one-line change, replacing timer = do_bench(f)[0] with timer = do_bench(f), was applied in both changed files:

@@ -16,7 +16,7 @@ def f():

        if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
            print(f"{N}: {e.name}")
            timer = e.cuda_time/1e3
    timer = do_bench(f)
    iters_per_second = 1e3/timer
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops/1e12
Chillee renamed this gist
Feb 21, 2023. 1 changed file with 0 additions and 0 deletions.
File renamed without changes.
Chillee revised this gist
Feb 21, 2023. 1 changed file with 26 additions and 0 deletions.
@@ -0,0 +1,26 @@

import torch
from triton.testing import do_bench

def get_flops(N, get_kernels=False):
    A = torch.randn(N, N, device='cuda', dtype=torch.float16)
    B = torch.randn(N, N, device='cuda', dtype=torch.float16)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        for e in prof.events():
            if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                print(f"{N}: {e.name}")
                timer = e.cuda_time/1e3
    timer = do_bench(f)[0]
    iters_per_second = 1e3/timer
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops/1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")

for N in range(1, 4096):
    get_flops(N)
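For reference, the flops expression above works out to 2*N^3 for an N x N by N x N matmul. A quick worked example of the TF/s arithmetic (illustrative timing, not part of the gist):

    # Worked example of the TF/s formula used in get_flops() (illustrative timing only).
    N = 4096
    flops = N * N * N * 2              # 2*N^3, about 137.4 GFLOP per matmul
    ms_per_iter = 1.0                  # hypothetical do_bench result, in milliseconds
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * flops / 1e12:.2f} TF/s")   # ~137.44 TF/s at 1 ms per matmul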
Chillee revised this gist
Feb 21, 2023. 1 changed file with 26 additions and 0 deletions.
@@ -0,0 +1,26 @@

import torch
from triton.testing import do_bench

def get_flops(N, get_kernels=False):
    A = torch.randn(N, N, device='cuda', dtype=torch.float16)
    B = torch.randn(N, N, device='cuda', dtype=torch.float16)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        for e in prof.events():
            if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                print(f"{N}: {e.name}")
                timer = e.cuda_time/1e3
    timer = do_bench(f)[0]
    iters_per_second = 1e3/timer
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops/1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")

for N in range(1, 4096):
    get_flops(N)
Chillee revised this gist
Feb 1, 2023. 1 changed file with 57 additions and 0 deletions.
@@ -0,0 +1,57 @@

import torch
torch.set_float32_matmul_precision('high')
import torch._inductor.config
torch._inductor.config.debug = True

def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    import time
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time()-begin)*1e6/iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.3f}us"
    if display:
        print(res)
    return res

def get_bandwidth(name, f):
    iters_per_second = 1e6/bench(f, display=False)
    bytes_accessed = N**2*4*3
    print(f"{name}: {iters_per_second * bytes_accessed/1e9:.2f}GB")

N = 2**14

def f(a, b):
    return a + b

A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
# eager: 1389.84GB
get_bandwidth("eager", lambda: f(A, B))
# torch.compile: 1388.19GB
get_bandwidth("torch.compile", lambda: torch.compile(f)(A, B))

def f2(a, b):
    return a + b.t()

A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
# eager: 904.01GB
get_bandwidth("eager", lambda: f2(A, B))
# torch.compile: 1334.89GB
get_bandwidth("torch.compile", lambda: torch.compile(f2)(A, B))
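The bytes_accessed term above counts three N x N fp32 tensors per add: reading a, reading b, and writing the result. A rough check with an illustrative timing (not part of the gist):

    # Breakdown of bytes_accessed = N**2 * 4 * 3 in get_bandwidth() (illustrative timing only).
    N = 2**14
    bytes_accessed = N * N * 4 * 3      # read A, read B, write output: ~3.2 GB per add
    us_per_iter = 2300.0                # hypothetical bench() result, in microseconds
    print(f"{(1e6 / us_per_iter) * bytes_accessed / 1e9:.2f} GB/s")   # ~1400 GB/s for these numbers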
Chillee revised this gist
Jan 21, 2023. 1 changed file with 42 additions and 0 deletions.
@@ -0,0 +1,42 @@

import torch
from torch.nn import *
torch.set_float32_matmul_precision('high')

def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    import time
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time()-begin)*1e6/iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.2f}us"
    if display:
        print(res)
    return res

import torchvision.models as models
mod = models.resnet18().eval().cuda()
opt_mod = torch.compile(mod, mode="reduce-overhead")
inp = torch.randn(1, 3, 224, 224).cuda()

with torch.no_grad():
    # Eager: 1938.18us
    bench(lambda: mod(inp), "Eager")
    # torch.compile (default): 953.96us
    # torch.compile (reduce-overhead): 744.02us
    bench(lambda: opt_mod(inp), "torch.compile (reduce-overhead)")
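From the timings recorded in the comments above, the implied speedups over eager are roughly 2x for the default mode and 2.6x for reduce-overhead. A small check of that arithmetic (not part of the gist):

    # Speedups implied by the timings noted in the comments above.
    eager_us, default_us, reduce_overhead_us = 1938.18, 953.96, 744.02
    print(f"default:         {eager_us / default_us:.2f}x")            # ~2.0x
    print(f"reduce-overhead: {eager_us / reduce_overhead_us:.2f}x")    # ~2.6x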
Chillee revised this gist
Dec 12, 2022. 1 changed file with 2 additions and 0 deletions.
@@ -1,5 +1,7 @@

import torch
import torch._inductor.config
import time
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')
Chillee revised this gist
Dec 12, 2022. 1 changed file with 1 addition and 0 deletions.
@@ -1,4 +1,5 @@

import torch
import torch._inductor.config
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')
Chillee created this gist
Dec 12, 2022.
@@ -0,0 +1,38 @@

import torch
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')

def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time()-begin)*1e6/iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter}us"
    if display:
        print(res)
    return res

def f1(a, b, c, d):
    a = a.relu()
    b = b.tanh()
    e = a * b
    f = (c + 2).cos()
    return (e + f) * d

inp = [torch.randn(2**24, device='cuda') for _ in range(4)]

f = f1
nf = torch.compile(f)

bench(lambda: f(*inp), name="eager")
bench(lambda: nf(*inp), name="PT 2.0")
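For context, f1 is purely pointwise, so a fully fused kernel only needs to read the four 2**24-element fp32 inputs and write one output. A rough traffic estimate with an illustrative timing (not part of the gist):

    # Rough memory-traffic estimate for a fused version of f1 (illustrative timing only).
    numel = 2**24
    bytes_moved = 5 * numel * 4         # read a, b, c, d and write the result, all fp32
    us_per_iter = 250.0                 # hypothetical bench() result, in microseconds
    print(f"{(1e6 / us_per_iter) * bytes_moved / 1e9:.1f} GB/s")   # ~1342 GB/s for these numbers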