import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

head_num = 16
dim = 128
seq_len = 100
chunk_size = 5
batch_size = 1

# Inputs are created as fp32 leaves, then moved to CUDA and cast to bf16,
# so q/k/v end up as non-leaf tensors and need retain_grad() to populate .grad.
q = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)
k = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)
v = torch.randn(batch_size, head_num, seq_len, dim, requires_grad=True).cuda().to(torch.bfloat16)

# Cotangents for the attention output and the logsumexp.
g0 = torch.randn(batch_size, head_num, seq_len, dim, dtype=torch.bfloat16, device='cuda')
g1 = torch.randn(batch_size, head_num, seq_len, dtype=torch.bfloat16, device='cuda')

q.retain_grad()
k.retain_grad()
v.retain_grad()

actual_out, actual_lse = flex_attention(q, k, v, block_mask=None, return_lse=True)

# Check whether these slots of the autograd node's saved tensors hold the
# forward output and the logsumexp (1.0 means every element matches).
print((actual_out.grad_fn.saved_tensors[3] == actual_out).float().mean())
print((actual_out.grad_fn.saved_tensors[4] == actual_lse).float().mean())
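
# Not part of the original snippet: a minimal sketch of how the pre-built
# cotangent g0 could drive the backward pass. retain_grad() above lets the
# non-leaf q/k/v tensors accumulate .grad; g1 is left unused here, since
# gradients w.r.t. the returned lse are not assumed to be supported.
actual_out.backward(g0)
print(q.grad.shape, k.grad.shape, v.grad.shape)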