@yiliu30
Forked from Chillee/1-pw_op_fusion.py
Created July 7, 2024 04:12
Revisions

  1. @Chillee revised this gist Feb 26, 2024. 1 changed file with 25 additions and 23 deletions.

     5-moe-poc.py (25 additions, 23 deletions):

     @@ -11,35 +11,37 @@
      
     -D = 2048
      E = 8
     -def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
     -    import time
     -    from triton.testing import do_bench
     -
     -    for _ in range(warmup):
     -        f()
     -    if profile:
     -        with torch.profiler.profile() as prof:
     -            f()
     -        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
     -
     -    us_per_iter = do_bench(lambda: f())*1000
     -    print(f"{name}: {(1e6/us_per_iter) * 2 * D * D * 4 / 1e9} GB/s")
     -
     -    return 0
     +for D in [1024, 2048, 4096, 8192, 16384]:
     +    def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
     +        import time
     +        from triton.testing import do_bench
     +
     +        for _ in range(warmup):
     +            f()
     +        if profile:
     +            with torch.profiler.profile() as prof:
     +                f()
     +            prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
     +
     +        us_per_iter = do_bench(lambda: f())*1000
     +        print(f"{name}: {(1e6/us_per_iter) * 2 * D * D * 4 / 1e9} GB/s")
     +
     +        return 0
      
      
      
     -def cuda_indexing(W, score_idxs, x):
     -    return W[score_idxs] @ x
     +    def cuda_indexing(W, score_idxs, x):
     +        return W[score_idxs] @ x
      
     -def python_indexing(W, score_idxs, x):
     -    return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x
     +    def python_indexing(W, score_idxs, x):
     +        return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x
      
     -W = torch.randn(E, D, D)
     -x = torch.randn(D)
     -score_idxs = torch.tensor([3, 5])
     +    W = torch.randn(E, D, D)
     +    x = torch.randn(D)
     +    score_idxs = torch.tensor([3, 5])
      
     -compiled_gather = torch.compile(gather_moe)
     -bench(lambda: branch_moe(W, score_idxs, x), "python indexing")
     -bench(lambda: gather_moe(W, score_idxs, x), "eager CUDA indexing")
     -bench(lambda: compiled_gather(W, score_idxs, x), "compiled CUDA indexing")
     +    compiled_cuda = torch.compile(cuda_indexing, dynamic=False)
     +    print(f"D={D}")
     +    bench(lambda: python_indexing(W, score_idxs, x), "python indexing")
     +    bench(lambda: cuda_indexing(W, score_idxs, x), "eager CUDA indexing")
     +    bench(lambda: compiled_cuda(W, score_idxs, x), "compiled CUDA indexing")
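
     For reference, the GB/s figure that bench prints counts only the two gathered expert
     matrices: 2 * D * D floats * 4 bytes, about 33.6 MB per call at D = 2048. Below is a
     minimal sanity check (a sketch, not part of the gist; assumes a CUDA device) that the
     two indexing styles compute the same values:

     import torch
     torch.set_default_device('cuda')

     D, E = 2048, 8
     W = torch.randn(E, D, D)
     x = torch.randn(D)
     score_idxs = torch.tensor([3, 5])

     out = W[score_idxs] @ x                  # one gather + batched matmul -> shape (2, D)
     ref = torch.stack([W[3] @ x, W[5] @ x])  # two separate matmuls
     assert torch.allclose(out, ref, atol=1e-3)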
  2. @Chillee revised this gist Feb 26, 2024. 1 changed file with 45 additions and 0 deletions.

     5-moe-poc.py (new file, 45 additions):

     import torch
     from torch import nn
     import torch.nn.functional as F
     import torch.autograd as autograd
     torch.set_default_device('cuda')
     import torch._inductor.config

     torch._inductor.config.triton.unique_kernel_names = True
     torch._inductor.config.coordinate_descent_tuning = True
     torch._inductor.config.assert_indirect_indexing = False

     D = 2048
     E = 8
     def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
         import time
         from triton.testing import do_bench

         for _ in range(warmup):
             f()
         if profile:
             with torch.profiler.profile() as prof:
                 f()
             prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

         us_per_iter = do_bench(lambda: f())*1000
         print(f"{name}: {(1e6/us_per_iter) * 2 * D * D * 4 / 1e9} GB/s")

         return 0


     def cuda_indexing(W, score_idxs, x):
         return W[score_idxs] @ x

     def python_indexing(W, score_idxs, x):
         return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x

     W = torch.randn(E, D, D)
     x = torch.randn(D)
     score_idxs = torch.tensor([3, 5])

     # Note: gather_moe and branch_moe are undefined here (stale names for
     # cuda_indexing and python_indexing); revision 1 above corrects the calls.
     compiled_gather = torch.compile(gather_moe)
     bench(lambda: branch_moe(W, score_idxs, x), "python indexing")
     bench(lambda: gather_moe(W, score_idxs, x), "eager CUDA indexing")
     bench(lambda: compiled_gather(W, score_idxs, x), "compiled CUDA indexing")
  3. @Chillee revised this gist Jul 7, 2023. 2 changed files with 2 additions and 2 deletions.

     3-matmul-bench.py (1 addition, 1 deletion):

     @@ -16,7 +16,7 @@ def f():
                  if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                      print(f"{N}: {e.name}")
                      timer = e.cuda_time/1e3
     -    timer = do_bench(f)[0]
     +    timer = do_bench(f)
          iters_per_second = 1e3/timer
          flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
          flops_achieved = iters_per_second * flops/1e12

     4-matmul-bench.py (1 addition, 1 deletion):

     @@ -16,7 +16,7 @@ def f():
                  if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                      print(f"{N}: {e.name}")
                      timer = e.cuda_time/1e3
     -    timer = do_bench(f)[0]
     +    timer = do_bench(f)
          iters_per_second = 1e3/timer
          flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
          flops_achieved = iters_per_second * flops/1e12
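
     This tracks a Triton API change: older triton.testing.do_bench returned a tuple of
     timings (hence the [0]), while newer versions return a single time in milliseconds.
     A usage sketch against the newer API (assumes a recent Triton):

     import torch
     from triton.testing import do_bench

     A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
     ms = do_bench(lambda: torch.mm(A, A))   # a single float: milliseconds per iteration
     print(f"{1e3 / ms:.1f} iters/s")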
  4. @Chillee renamed this gist Feb 21, 2023. 1 changed file with 0 additions and 0 deletions.

     File renamed without changes.
  5. @Chillee revised this gist Feb 21, 2023. 1 changed file with 26 additions and 0 deletions.

     04-matmuls.py (new file, 26 additions):

     import torch
     from triton.testing import do_bench

     def get_flops(N, get_kernels=False):
         A = torch.randn(N, N, device='cuda', dtype=torch.float16)
         B = torch.randn(N, N, device='cuda', dtype=torch.float16)

         def f():
             return torch.mm(A, B)

         if get_kernels:
             with torch.profiler.profile() as prof:
                 f()

             for e in prof.events():
                 if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                     print(f"{N}: {e.name}")
                     timer = e.cuda_time/1e3
         timer = do_bench(f)[0]
         iters_per_second = 1e3/timer
         flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
         flops_achieved = iters_per_second * flops/1e12
         print(f"{N}: {flops_achieved:.2f}TF/s")

     for N in range(1, 4096):
         get_flops(N)
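
     As a worked example of the formula above: an N x N matmul does N**3 multiply-adds,
     i.e. 2 * N**3 FLOPs, so throughput follows directly from the measured time (a sketch
     with an assumed timing, not a measurement):

     N = 4096
     flops = 2 * N**3                                   # ~1.37e11 FLOPs per matmul
     timer = 1.0                                        # hypothetical do_bench result, in ms
     print(f"{(1e3/timer) * flops / 1e12:.1f} TF/s")    # ~137.4 TF/s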
  6. @Chillee revised this gist Feb 21, 2023. 1 changed file with 26 additions and 0 deletions.

     3-matmul-bench.py (new file, 26 additions):

     import torch
     from triton.testing import do_bench

     def get_flops(N, get_kernels=False):
         A = torch.randn(N, N, device='cuda', dtype=torch.float16)
         B = torch.randn(N, N, device='cuda', dtype=torch.float16)

         def f():
             return torch.mm(A, B)

         if get_kernels:
             with torch.profiler.profile() as prof:
                 f()

             for e in prof.events():
                 if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                     print(f"{N}: {e.name}")
                     timer = e.cuda_time/1e3
         timer = do_bench(f)[0]
         iters_per_second = 1e3/timer
         flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
         flops_achieved = iters_per_second * flops/1e12
         print(f"{N}: {flops_achieved:.2f}TF/s")

     for N in range(1, 4096):
         get_flops(N)
  7. @Chillee revised this gist Feb 1, 2023. 1 changed file with 57 additions and 0 deletions.

     3-tiling.py (new file, 57 additions; units fixed to GB/s, the quantity actually printed):

     import torch
     torch.set_float32_matmul_precision('high')
     import torch._inductor.config
     torch._inductor.config.debug = True


     def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
         import time
         for _ in range(warmup):
             f()
         if profile:
             with torch.profiler.profile() as prof:
                 f()
             prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

         torch.cuda.synchronize()
         begin = time.time()
         for _ in range(iters):
             f()
         torch.cuda.synchronize()
         us_per_iter = (time.time()-begin)*1e6/iters

         if name is None:
             res = us_per_iter
         else:
             res = f"{name}: {us_per_iter:.3f}us"

         if display:
             print(res)
         return res

     def get_bandwidth(name, f):
         iters_per_second = 1e6/bench(f, display=False)
         bytes_accessed = N**2*4*3
         print(f"{name}: {iters_per_second * bytes_accessed/1e9:.2f}GB/s")

     N = 2**14

     def f(a, b):
         return a + b
     A = torch.randn(N, N, device='cuda')
     B = torch.randn(N, N, device='cuda')

     # eager: 1389.84GB/s
     get_bandwidth("eager", lambda: f(A, B))
     # torch.compile: 1388.19GB/s
     get_bandwidth("torch.compile", lambda: torch.compile(f)(A, B))

     def f2(a, b):
         return a + b.t()
     A = torch.randn(N, N, device='cuda')
     B = torch.randn(N, N, device='cuda')

     # eager: 904.01GB/s
     get_bandwidth("eager", lambda: f2(A, B))
     # torch.compile: 1334.89GB/s
     get_bandwidth("torch.compile", lambda: torch.compile(f2)(A, B))
  8. @Chillee revised this gist Jan 21, 2023. 1 changed file with 42 additions and 0 deletions.

     2-overhead.py (new file, 42 additions):

     import torch
     from torch.nn import *
     torch.set_float32_matmul_precision('high')

     def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
         import time
         for _ in range(warmup):
             f()
         if profile:
             with torch.profiler.profile() as prof:
                 f()
             prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

         torch.cuda.synchronize()
         begin = time.time()
         for _ in range(iters):
             f()
         torch.cuda.synchronize()
         us_per_iter = (time.time()-begin)*1e6/iters
         if name is None:
             res = us_per_iter
         else:
             res = f"{name}: {us_per_iter:.2f}us"
         if display:
             print(res)
         return res


     import torchvision.models as models

     mod = models.resnet18().eval().cuda()
     opt_mod = torch.compile(mod, mode="reduce-overhead")

     inp = torch.randn(1, 3, 224, 224).cuda()


     with torch.no_grad():
         # Eager: 1938.18us
         bench(lambda: mod(inp), "Eager")
         # torch.compile (default): 953.96us
         # torch.compile (reduce-overhead): 744.02us
         bench(lambda: opt_mod(inp), "torch.compile (reduce-overhead)")
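
     The "reduce-overhead" mode gets most of its win from CUDA graphs, which record the
     model's kernels once and relaunch them all with a single call, eliminating per-kernel
     launch overhead. A minimal sketch of the underlying capture-and-replay API, using a
     toy module rather than the resnet18 above (assumes a recent PyTorch with CUDA):

     import torch

     mod = torch.nn.Conv2d(3, 8, 3).cuda().eval()
     static_inp = torch.randn(1, 3, 224, 224, device='cuda')

     # Warm up on a side stream before capture, as the capture API requires.
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s), torch.no_grad():
         for _ in range(3):
             mod(static_inp)
     torch.cuda.current_stream().wait_stream(s)

     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g), torch.no_grad():
         static_out = mod(static_inp)

     # Replay: copy new data into the static input buffer, then relaunch everything at once.
     static_inp.copy_(torch.randn(1, 3, 224, 224, device='cuda'))
     g.replay()
     print(static_out.shape)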
  9. @Chillee revised this gist Dec 12, 2022. 1 changed file with 2 additions and 0 deletions.

     1-pw_op_fusion.py (2 additions):

     @@ -1,5 +1,7 @@
      import torch
      import torch._inductor.config
     +import time
     +
      torch._inductor.config.triton.cudagraphs = False
      torch.set_float32_matmul_precision('high')
  10. @Chillee revised this gist Dec 12, 2022. 1 changed file with 1 addition and 0 deletions.

      1-pw_op_fusion.py (1 addition):

      @@ -1,4 +1,5 @@
       import torch
      +import torch._inductor.config
       torch._inductor.config.triton.cudagraphs = False
       torch.set_float32_matmul_precision('high')
  11. @Chillee created this gist Dec 12, 2022. 1 changed file with 38 additions and 0 deletions.

      1-pw_op_fusion.py (new file, 38 additions):

      # Note: this first version uses `time` and `torch._inductor.config` without
      # importing them; revisions 10 and 9 above add the missing imports.
      import torch
      torch._inductor.config.triton.cudagraphs = False
      torch.set_float32_matmul_precision('high')

      def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
          for _ in range(warmup):
              f()
          if profile:
              with torch.profiler.profile() as prof:
                  f()
              prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

          torch.cuda.synchronize()
          begin = time.time()
          for _ in range(iters):
              f()
          torch.cuda.synchronize()
          us_per_iter = (time.time()-begin)*1e6/iters
          if name is None:
              res = us_per_iter
          else:
              res = f"{name}: {us_per_iter}us"
          if display:
              print(res)
          return res

      def f1(a, b, c, d):
          a = a.relu()
          b = b.tanh()
          e = a * b
          f = (c + 2).cos()
          return (e + f) * d
      inp = [torch.randn(2**24, device='cuda') for _ in range(4)]

      f = f1
      nf = torch.compile(f)
      bench(lambda: f(*inp), name="eager")
      bench(lambda: nf(*inp), name="PT 2.0")
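
      The win here is memory traffic: compiled, f1 becomes a single fused kernel that
      reads the four inputs and writes one output, while eager mode round-trips every
      intermediate through HBM. A back-of-envelope count (assumes float32 and no buffer
      reuse in eager mode):

      MB = 2**24 * 4 / 2**20                      # 64 MB per tensor
      fused = 5 * MB                              # read a, b, c, d + write result = 320 MB
      eager = (2 + 2 + 3 + 2 + 2 + 3 + 3) * MB    # relu, tanh, mul, add, cos, add, mul = 1088 MB
      print(f"fusion cuts traffic ~{eager / fused:.1f}x")   # ~3.4x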