softmax_quack.py
Forked by @jeromeku from @Chillee/softmax_quack.py (original gist created Jul 10, 2025; fork created Jul 11, 2025).
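# Benchmarks quack's softmax kernel against a torch.compile'd F.softmax reference
# across a sweep of (M, N) shapes, reporting kernel time and effective memory
# bandwidth modeled as (bytes read + bytes written) / elapsed time.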
import argparse
import time
from typing import Type

import torch
import torch.nn.functional as F
import torch._inductor.config

torch._inductor.config.triton.multi_kernel = True

from triton.testing import do_bench

import cutlass
import cutlass.torch as cutlass_torch

from quack.softmax import softmax

def run_softmax(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("Ampere GPU is required to run this example!")

    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")

    torch_dtype = cutlass_torch.dtype(dtype)

    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype)

    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    out = softmax(x)  # warm up / compile the quack kernel once
    torch._dynamo.config.recompile_limit = 1024
    compiled_func_ref = torch.compile(
        lambda x: F.softmax(x, dim=-1), dynamic=False, mode="max-autotune-no-cudagraphs"
    )

    fn = lambda: softmax(x)
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Memory traffic: read x and write out, dtype.width bits per element
    mem_bw = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw:.2f} GB/s")

    fn = lambda: compiled_func_ref(x)
    for _ in range(5):
        fn()  # warm up
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")

    return mem_bw, mem_bw_ref
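
# Optional sanity check (not part of the original benchmark): compare quack's
# softmax against F.softmax before timing. The shape and tolerances below are
# assumptions chosen for bf16/fp16 inputs.
def check_softmax_correctness(M=1024, N=4096, torch_dtype=torch.bfloat16):
    x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype)
    out = softmax(x)
    ref = F.softmax(x, dim=-1)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)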


def run_softmax_backward(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("Ampere GPU is required to run this example!")

    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")

    torch_dtype = cutlass_torch.dtype(dtype)

    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_()

    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")

    y = softmax(x)
    dy = torch.randn_like(y)

    time.sleep(0.5)
    fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Memory traffic: read dy and y, write dx
    mem_bw = round(3 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw:.2f} GB/s")

    # Reference implementation
    y_ref = F.softmax(x_ref, dim=-1)
    compiled_func_ref = torch.compile(
        lambda: torch.autograd.grad(y_ref, x_ref, grad_outputs=dy, retain_graph=True)
    )

    for _ in range(5):
        compiled_func_ref()  # warm up
    time.sleep(0.5)
    avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(3 * x.numel() * dtype.width // 8 / (avg_time_ref / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")

    return mem_bw, mem_bw_ref


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark softmax forward and backward passes"
    )
    parser.add_argument("--M", default=8192, type=int)
    parser.add_argument("--N", default=16384, type=int)
    parser.add_argument(
        "--dtype",
        type=cutlass.dtype,
        choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32],
        default=cutlass.BFloat16,
    )
    parser.add_argument("--warmup_iterations", default=10, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument(
        "--backward", action="store_true", help="Benchmark backward pass instead of forward pass"
    )

    args = parser.parse_args()
    torch.manual_seed(0)

    # if args.backward:
    #     print("=== Softmax Backward Pass Benchmark ===")
    #     run_softmax_backward(
    #         args.M,
    #         args.N,
    #         dtype=args.dtype,
    #         warmup_iterations=args.warmup_iterations,
    #         iterations=args.iterations,
    #     )
    # else:
    #     print("=== Softmax Forward Pass Benchmark ===")
    #     run_softmax(
    #         args.M,
    #         args.N,
    #         dtype=args.dtype,
    #         warmup_iterations=args.warmup_iterations,
    #         iterations=args.iterations,
    #     )
    # exit(0)

    # Sweep a range of (rows, cols) shapes instead of the single --M/--N pair
    MN_pairs = [
        (32768, 256), (32768, 512), (32768, 1024), (32768, 2048),
        (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768),
        (32768, 65536), (16384, 131072), (8192, 262144),
    ]
    # MN_pairs = [(32768, 32768)]
    # MN_pairs = [(32768, 1024)]
    results = []
    for M, N in MN_pairs:
        res = run_softmax(
            M,
            N,
            dtype=args.dtype,
            warmup_iterations=args.warmup_iterations,
            iterations=args.iterations,
        )
        results.append(res)
    print(results)
    # print([x for x, _ in results])
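
# Example invocation (assumes a CUDA GPU plus the quack and cutlass Python packages;
# note the shape sweep above ignores --M/--N):
#   python softmax_quack.py --warmup_iterations 10 --iterations 100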