@rishabh135
Forked from Chillee/merge_attention.py
Created February 17, 2025 06:13

Revisions

  1. @Chillee revised this gist Feb 5, 2025 (merge_attention.py: 1 addition, 0 deletions): added the identity comment above merge_attention, shown in the file below.
  2. @Chillee created this gist Feb 5, 2025 (merge_attention.py: 27 additions).

merge_attention.py, with the revision applied:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention
    torch.set_default_device('cuda')

    # (batch=8, heads=8, seq_len=1024, head_dim=64)
    q, k, v = [torch.randn(8, 8, 1024, 64, requires_grad=True) for _ in range(3)]

    # Two complementary block masks: q_idx >= kv_idx keeps the causal half of the
    # score matrix and q_idx < kv_idx keeps the strictly anti-causal half, so
    # together they cover every (q, kv) pair of full attention.
    causal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx >= kv_idx, None, None, 1024, 1024)
    uncausal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx < kv_idx, None, None, 1024, 1024)
    ref_out = flex_attention(q, k, v)
    # return_lse=True also returns the log-sum-exp of each row's attention scores,
    # which is exactly the extra statistic needed to merge partial results.
    causal_out, causal_lse = flex_attention(q, k, v, block_mask=causal_mask, return_lse=True)
    uncausal_out, uncausal_lse = flex_attention(q, k, v, block_mask=uncausal_mask, return_lse=True)

    # merge_attention(*attention(q, k1, v1), *attention(q, k2, v2)) == attention(q, cat(k1, k2), cat(v1, v2))
    def merge_attention(a, lse_a, b, lse_b):
        # Shift both LSEs by their elementwise max so the exponentials below
        # cannot overflow; the shift cancels in the final ratio.
        max_lse = torch.maximum(lse_a, lse_b)
        lse_a = torch.exp(lse_a - max_lse)
        lse_b = torch.exp(lse_b - max_lse)
        # Reweight each partial output by its (shifted) softmax denominator and
        # renormalize by the combined denominator.
        out = ((a * lse_a[..., None] + b * lse_b[..., None]) / (lse_a + lse_b)[..., None])
        return out

    # Merging the causal and anti-causal halves should reproduce full attention.
    merge_out = merge_attention(causal_out, causal_lse, uncausal_out, uncausal_lse)
    assert (ref_out - merge_out).abs().max() < 1e-5

    # The merge is built from differentiable ops, so gradients through the merged
    # output should also match gradients through the reference.
    ref_out.sum().backward()
    ref_q_grad = q.grad
    q.grad = None
    merge_out.sum().backward()
    assert (q.grad - ref_q_grad).abs().max() < 1e-5
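
Why the merge is exact (a short derivation for reference, not part of the original gist): let s_i be the attention scores for one query row and let A and B be the two disjoint key sets. Each partial output is the softmax-weighted average over its own set, and the LSE recovers that set's softmax denominator:

    a = \frac{\sum_{i \in A} e^{s_i} v_i}{\sum_{i \in A} e^{s_i}}, \qquad
    e^{\mathrm{lse}_A} = \sum_{i \in A} e^{s_i}.

With m = \max(\mathrm{lse}_A, \mathrm{lse}_B), the merged output computed above is

    \frac{e^{\mathrm{lse}_A - m}\, a + e^{\mathrm{lse}_B - m}\, b}{e^{\mathrm{lse}_A - m} + e^{\mathrm{lse}_B - m}}
    = \frac{\sum_{i \in A \cup B} e^{s_i} v_i}{\sum_{i \in A \cup B} e^{s_i}},

i.e. exactly softmax attention over the union of the two key sets; the stabilizing shift by m cancels.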
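
The identity in the comment above merge_attention can also be exercised literally. The following is a minimal extra check, not in the original gist, reusing the q, k, v and merge_attention defined above; the variable names are illustrative, and K/V are split along the sequence dimension (dim=2):

    # Attention over each half of the KV sequence, merged, should equal
    # attention over the full sequence (cat(k1, k2) == k here by construction).
    k1, k2 = k.split(512, dim=2)
    v1, v2 = v.split(512, dim=2)
    half1_out, half1_lse = flex_attention(q, k1, v1, return_lse=True)
    half2_out, half2_lse = flex_attention(q, k2, v2, return_lse=True)
    split_out = merge_attention(half1_out, half1_lse, half2_out, half2_lse)
    assert (ref_out - split_out).abs().max() < 1e-5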
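
Because the merged result is again an (output, LSE) pair once the LSEs are combined with logaddexp, the merge composes across any number of KV chunks. A sketch of that generalization under the same assumptions (merge_attention_with_lse is a hypothetical helper, not part of the gist):

    # Hypothetical helper, not part of the original gist: also return the
    # merged LSE so partial results can be folded pairwise.
    def merge_attention_with_lse(a, lse_a, b, lse_b):
        out = merge_attention(a, lse_a, b, lse_b)
        # logaddexp(lse_a, lse_b) is the LSE over the union of the two key sets.
        return out, torch.logaddexp(lse_a, lse_b)

    # Fold attention over four KV chunks into one result.
    acc = None
    for kc, vc in zip(k.chunk(4, dim=2), v.chunk(4, dim=2)):
        o, lse = flex_attention(q, kc, vc, return_lse=True)
        acc = (o, lse) if acc is None else merge_attention_with_lse(*acc, o, lse)
    assert (ref_out - acc[0]).abs().max() < 1e-5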