    Random Kernel Microbenchmarks
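# Microbenchmark of quack's CuTe-DSL softmax kernel against a torch.compile'd
# F.softmax reference: sweeps a list of (M, N) shapes and reports effective
# memory bandwidth (GB/s) for each implementation.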
  
        
  
    
    
  
  
    
import argparse
import time
from typing import Type

import torch
import torch.nn.functional as F
import torch._inductor.config

# Let Inductor generate multiple kernel variants and pick the fastest at runtime.
torch._inductor.config.triton.multi_kernel = True

from triton.testing import do_bench

import cutlass
import cutlass.torch as cutlass_torch

from quack.softmax import softmax
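# Note (Triton API assumption): triton.testing.do_bench treats warmup/rep as
# time budgets in milliseconds rather than iteration counts, so the
# *_iterations arguments below are effectively millisecond budgets.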
def run_softmax(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("Ampere GPU is required to run this example!")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
    torch_dtype = cutlass_torch.dtype(dtype)
    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype)
    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    out = softmax(x)  # one eager call to trigger kernel compilation before timing
    torch._dynamo.config.recompile_limit = 1024
    compiled_func_ref = torch.compile(
        lambda x: F.softmax(x, dim=-1), dynamic=False, mode="max-autotune-no-cudagraphs"
    )
    fn = lambda: softmax(x)
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Traffic model: read x once, write out once -> 2 tensors of M * N elements.
    mem_bw = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw} GB/s")
    fn = lambda: compiled_func_ref(x)
    for _ in range(5):
        fn()  # warm up
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref} GB/s")
    return mem_bw, mem_bw_ref
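# The GB/s numbers above come from a simple traffic model: a fused softmax
# forward reads x once and writes out once, i.e. 2 * M * N elements. A minimal
# standalone sketch of that model (illustrative helper, not part of the
# original script):
def _model_fwd_bw_gbps(M, N, bytes_per_element, time_ms):
    # One read of x plus one write of out, in bytes, over the measured time.
    bytes_moved = 2 * M * N * bytes_per_element
    return bytes_moved / (time_ms / 1e3) / 1e9

# e.g. _model_fwd_bw_gbps(8192, 16384, 2, 0.1) ~= 5369 GB/s for bf16 at 0.1 ms.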
def run_softmax_backward(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("Ampere GPU is required to run this example!")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
    torch_dtype = cutlass_torch.dtype(dtype)
    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_()
    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    y = softmax(x)
    dy = torch.randn_like(y)
    time.sleep(0.5)
    fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Traffic model: read dy and y, write dx -> 3 tensors of M * N elements.
    mem_bw = round(3 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw} GB/s")
    # Reference implementation
    y_ref = F.softmax(x_ref, dim=-1)
    compiled_func_ref = torch.compile(
        lambda: torch.autograd.grad(y_ref, x_ref, grad_outputs=dy, retain_graph=True)
    )
    for _ in range(5):
        compiled_func_ref()  # warm up
    time.sleep(0.5)
    avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(3 * x.numel() * dtype.width // 8 / (avg_time_ref / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref} GB/s")
    return mem_bw, mem_bw_ref
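# The 3x factor above matches the fused softmax backward identity
#   dx = y * (dy - sum(dy * y, dim=-1, keepdim=True)),
# which reads y and dy and writes dx: three M x N tensors. A minimal eager-mode
# sketch (illustrative, not used by the benchmark):
def _softmax_backward_reference(y, dy):
    # Row-wise Jacobian-vector product of softmax.
    return y * (dy - (dy * y).sum(dim=-1, keepdim=True))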
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark softmax forward and backward passes"
    )
    # --M/--N are kept from the single-shape variant; the sweep below ignores them.
    parser.add_argument("--M", default=8192, type=int)
    parser.add_argument("--N", default=16384, type=int)
    parser.add_argument(
        "--dtype",
        type=cutlass.dtype,
        choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32],
        default=cutlass.BFloat16,
    )
    parser.add_argument("--warmup_iterations", default=10, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument(
        "--backward",
        action="store_true",
        help="Benchmark backward pass instead of forward pass",
    )
    args = parser.parse_args()
    torch.manual_seed(0)

    # Sweep a range of (M, N) shapes rather than a single (args.M, args.N) run.
    MN_pairs = [
        (32768, 256),
        (32768, 512),
        (32768, 1024),
        (32768, 2048),
        (32768, 4096),
        (32768, 8192),
        (32768, 16384),
        (32768, 32768),
        (32768, 65536),
        (16384, 131072),
        (8192, 262144),
    ]
    bench = run_softmax_backward if args.backward else run_softmax
    results = []
    for M, N in MN_pairs:
        res = bench(
            M,
            N,
            dtype=args.dtype,
            warmup_iterations=args.warmup_iterations,
            iterations=args.iterations,
        )
        results.append(res)
    print(results)
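# Example invocation (filename assumed; sweeps MN_pairs at the default bf16):
#   python softmax_bench.py
#   python softmax_bench.py --backward    # sweep the backward pass instead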
  