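"""Benchmark FlashAttention varlen (packed) attention against the vanilla
fixed-shape kernel on mixed-resolution batches.

For every pair of resolutions in DEFAULT_AR_MAP and every combination of frame
counts, the script times forward and backward passes of
flash_attn_varlen_qkvpacked_func (all sequences packed into a single call)
versus flash_attn_qkvpacked_func (one call per resolution), and prints
mean ± std over 10 timed iterations.
"""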
    import time
    import numpy as np
    import torch
    from itertools import accumulate

    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func


    device = torch.device('cuda:0')
    dtype = torch.bfloat16

    num_head = 16
    head_dim = 72
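# head_dim = 72 is a multiple of 8 and well under 256, so it falls within the
# head-dim range supported by the flash-attn kernels.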

DEFAULT_AR_MAP = {
    "144p": (144, 256),
    "240p": (240, 426),
    "360p": (360, 640),
    "480p": (480, 854),
    "720p": (720, 1280),
    "1080p": (1080, 1920),
    "2k": (1440, 2560),
}
    len_map = len(DEFAULT_AR_MAP)
    keys = list(DEFAULT_AR_MAP.keys())
for i in range(len_map):
    for j in range(i + 1, len_map):
        for nf1 in [1, 4, 8, 16]:
            for nf2 in [1, 4, 8, 16]:

                ar1 = DEFAULT_AR_MAP[keys[i]]
                ar2 = DEFAULT_AR_MAP[keys[j]]

                # Tokens per frame: one token per 16x16 patch.
                seq_len = [
                    ar1[0] // 16 * ar1[1] // 16,  # lower resolution (keys[i])
                    ar2[0] // 16 * ar2[1] // 16,  # higher resolution (keys[j])
                ]
                # Number of frames (independent sequences) at each resolution.
                bs = [nf1, nf2]

                batch_lens = []
                for _len, _bs in zip(seq_len, bs):
                    batch_lens += [_len] * _bs

                max_seqlen = max(batch_lens)
                # Prefix sums -> cumulative sequence boundaries: [0, l0, l0 + l1, ...]
                batch_lens = list(accumulate(batch_lens, initial=0))

                seqlens = torch.tensor(batch_lens, dtype=torch.int32, device=device)
                qkv_base = torch.empty((seqlens[-1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad_base = torch.empty((seqlens[-1], num_head, head_dim), dtype=dtype, device=device)
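
                # Shape contract for flash_attn_varlen_qkvpacked_func:
                #   qkv:        (total_tokens, 3, num_head, head_dim), sequences packed with no padding
                #   cu_seqlens: (num_seqs + 1,) int32 cumulative offsets (the `seqlens` tensor above)
                #   max_seqlen: length of the longest sequence in the packed batch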

                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()

                    assert qkv.requires_grad
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    out.backward(grad)

                # timed runs
                for _ in range(10):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    torch.cuda.synchronize()
                    fwd_times.append(time.time() - start)

                    start = time.time()
                    out.backward(grad)
                    torch.cuda.synchronize()
                    bwd_times.append(time.time() - start)

                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) varlen:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ± {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ± {np.std(bwd_times):.4f}")

                qkv1_base = torch.empty((bs[0], seq_len[0], 3, num_head, head_dim), dtype=dtype, device=device)
                qkv2_base = torch.empty((bs[1], seq_len[1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad1_base = torch.empty((bs[0], seq_len[0], num_head, head_dim), dtype=dtype, device=device)
                grad2_base = torch.empty((bs[1], seq_len[1], num_head, head_dim), dtype=dtype, device=device)

                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()

                    out = flash_attn_qkvpacked_func(qkv1)
                    out.backward(grad1)

                    out = flash_attn_qkvpacked_func(qkv2)
                    out.backward(grad2)

                # timed runs
                for _ in range(10):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv1)
                    torch.cuda.synchronize()
                    fwd_time = time.time() - start

                    start = time.time()
                    out.backward(grad1)
                    torch.cuda.synchronize()
                    bwd_time = time.time() - start

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv2)
                    torch.cuda.synchronize()
                    fwd_time += time.time() - start

                    start = time.time()
                    out.backward(grad2)
                    torch.cuda.synchronize()
                    bwd_time += time.time() - start

                    fwd_times.append(fwd_time)
                    bwd_times.append(bwd_time)

                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) vanilla:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ± {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ± {np.std(bwd_times):.4f}")