@Chillee · Created April 12, 2024
attention_dim_bench.py
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

torch.set_default_device('cuda')


def get_flops_achieved(f):
    # Count the FLOPs performed by one call to f ...
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    # ... then benchmark its wall-clock time and report achieved TFLOP/s.
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")


def attention(q, k, v):
    # Plain (non-fused) attention: the S x S score matrix is materialized.
    return torch.softmax(q @ k.T, dim=-1) @ v


S = 4096
# Sweep the head dimension to see how achieved throughput changes with D.
for D in [64, 128, 256, 512, 1024]:
    q = torch.randn(S, D, dtype=torch.bfloat16)
    k = torch.randn(S, D, dtype=torch.bfloat16)
    v = torch.randn(S, D, dtype=torch.bfloat16)
    print(f"D={D}")
    get_flops_achieved(lambda: attention(q, k, v))
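
As a rough sanity check on the counted FLOPs, the matmul cost of this attention can also be worked out by hand: both q @ k.T and scores @ v are matrix multiplies over shapes (S, D) x (D, S) and (S, S) x (S, D), so the forward pass is roughly 4·S²·D FLOPs (the softmax is a comparatively small elementwise cost and is typically not counted). A minimal sketch, not part of the original gist, using the same S and D values as above:

# Hypothetical sanity check: analytic matmul FLOPs for one forward pass
# of the attention function benchmarked above.
S = 4096
for D in [64, 128, 256, 512, 1024]:
    qk_flops = 2 * S * S * D      # q @ k.T: (S, D) x (D, S)
    av_flops = 2 * S * S * D      # softmax(scores) @ v: (S, S) x (S, D)
    total = qk_flops + av_flops   # softmax itself is ignored here
    print(f"D={D}: ~{total / 1e9:.1f} GFLOP per forward pass")

Dividing this figure by the measured milliseconds per iteration should land close to the TF/s numbers printed by get_flops_achieved, assuming FlopCounterMode attributes essentially all FLOPs to the two matmuls.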