Last active: September 19, 2024 19:32
Revisions
- zxgx renamed this gist (Sep 19, 2024): 1 changed file with 0 additions and 0 deletions; file renamed without changes.
- zxgx created this gist (Sep 19, 2024).
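The script below benchmarks FlashAttention's packed variable-length kernel (flash_attn_varlen_qkvpacked_func) against calling the fixed-length kernel (flash_attn_qkvpacked_func) once per resolution group. For every pair of video resolutions and frame counts, it builds one packed QKV batch with cumulative sequence lengths, times forward and backward over 10 iterations after a short warmup, and prints mean ~ std seconds for the "varlen" and "vanilla" variants.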
# Benchmark: packed variable-length FlashAttention vs. one fixed-length (batched)
# call per resolution group, over pairs of video resolutions and frame counts.
import time
import numpy as np
import torch
from itertools import accumulate
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func

device = torch.device('cuda:0')
dtype = torch.bfloat16
num_head = 16
head_dim = 72

# (height, width) for each resolution label
DEFAULT_AR_MAP = {
    "144p": (144, 256),
    "240p": (240, 426),
    "360p": (360, 640),
    "480p": (480, 854),
    "720p": (720, 1280),
    "1080p": (1080, 1920),
    "2k": (1440, 2560),
}
len_map = len(DEFAULT_AR_MAP)
keys = list(DEFAULT_AR_MAP.keys())

for i in range(len_map):
    for j in range(i + 1, len_map):
        for nf1 in [1, 4, 8, 16]:      # number of frames at the lower resolution
            for nf2 in [1, 4, 8, 16]:  # number of frames at the higher resolution
                ar1 = DEFAULT_AR_MAP[keys[i]]
                ar2 = DEFAULT_AR_MAP[keys[j]]
                # tokens per frame for each resolution, assuming a 16x16 patch size
                seq_len = [
                    ar1[0] // 16 * ar1[1] // 16,
                    ar2[0] // 16 * ar2[1] // 16,
                ]
                bs = [nf1, nf2]

                # flatten all frames into one packed batch and build the cumulative
                # sequence lengths (cu_seqlens) expected by the varlen kernel
                batch_lens = []
                for _len, _bs in zip(seq_len, bs):
                    batch_lens += [_len] * _bs
                max_seqlen = max(batch_lens)
                batch_lens = list(accumulate(batch_lens, initial=0))
                seqlens = torch.tensor(batch_lens, dtype=torch.int32, device=device)

                qkv_base = torch.empty((seqlens[-1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad_base = torch.empty((seqlens[-1], num_head, head_dim), dtype=dtype, device=device)

                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()
                    assert qkv.requires_grad
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    out.backward(grad)

                # timed runs: packed variable-length attention
                for _ in range(10):
                    torch.rand(qkv_base.size(), device=device, dtype=dtype, out=qkv_base)
                    qkv = qkv_base.detach().requires_grad_(True)
                    torch.rand(grad_base.size(), device=device, dtype=dtype, out=grad_base)
                    grad = grad_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_varlen_qkvpacked_func(qkv, seqlens, max_seqlen)
                    torch.cuda.synchronize()
                    fwd_times.append(time.time() - start)

                    start = time.time()
                    out.backward(grad)
                    torch.cuda.synchronize()
                    bwd_times.append(time.time() - start)

                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) varlen:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")

                # baseline: one fixed-length (batched) attention call per resolution group
                qkv1_base = torch.empty((bs[0], seq_len[0], 3, num_head, head_dim), dtype=dtype, device=device)
                qkv2_base = torch.empty((bs[1], seq_len[1], 3, num_head, head_dim), dtype=dtype, device=device)
                grad1_base = torch.empty((bs[0], seq_len[0], num_head, head_dim), dtype=dtype, device=device)
                grad2_base = torch.empty((bs[1], seq_len[1], num_head, head_dim), dtype=dtype, device=device)

                fwd_times, bwd_times = [], []
                # warmup
                for _ in range(2):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()
                    out = flash_attn_qkvpacked_func(qkv1)
                    out.backward(grad1)
                    out = flash_attn_qkvpacked_func(qkv2)
                    out.backward(grad2)

                # timed runs: the two calls are timed separately and summed per iteration
                for _ in range(10):
                    torch.rand(qkv1_base.size(), device=device, dtype=dtype, out=qkv1_base)
                    torch.rand(qkv2_base.size(), device=device, dtype=dtype, out=qkv2_base)
                    torch.rand(grad1_base.size(), device=device, dtype=dtype, out=grad1_base)
                    torch.rand(grad2_base.size(), device=device, dtype=dtype, out=grad2_base)
                    qkv1 = qkv1_base.detach().requires_grad_(True)
                    qkv2 = qkv2_base.detach().requires_grad_(True)
                    grad1 = grad1_base.detach()
                    grad2 = grad2_base.detach()

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv1)
                    torch.cuda.synchronize()
                    fwd_time = (time.time() - start)

                    start = time.time()
                    out.backward(grad1)
                    torch.cuda.synchronize()
                    bwd_time = (time.time() - start)

                    torch.cuda.synchronize()
                    start = time.time()
                    out = flash_attn_qkvpacked_func(qkv2)
                    torch.cuda.synchronize()
                    fwd_time += (time.time() - start)

                    start = time.time()
                    out.backward(grad2)
                    torch.cuda.synchronize()
                    bwd_time += (time.time() - start)

                    fwd_times.append(fwd_time)
                    bwd_times.append(bwd_time)

                print(f"({keys[i]}*{nf1}, {keys[j]}*{nf2}) vanilla:\n"
                      f" - fwd: {np.mean(fwd_times):.4f} ~ {np.std(fwd_times):.4f}\n"
                      f" - bwd: {np.mean(bwd_times):.4f} ~ {np.std(bwd_times):.4f}")
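The original gist filename is not shown above; assuming the script is saved locally (the name below is hypothetical), it runs directly on a machine with a CUDA GPU, PyTorch, and flash-attn installed:

    python bench_flash_attn_varlen.py

The largest resolution and frame-count combinations allocate sizeable QKV and gradient buffers, so GPUs with limited memory may not get through the later pairs.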