Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active February 6, 2025 10:14
Show Gist options
  • Save Chillee/e3089e7a11419c6b85f68de170e0ba0c to your computer and use it in GitHub Desktop.
Save Chillee/e3089e7a11419c6b85f68de170e0ba0c to your computer and use it in GitHub Desktop.

Revisions

  1. Chillee revised this gist May 31, 2024. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions assoc_scan.py
    Original file line number Diff line number Diff line change
    @@ -2,6 +2,7 @@
    import torch.nn as nn
    from torch._higher_order_ops.associative_scan import associative_scan
    from triton.testing import do_bench
    torch.set_default_device('cuda')

    def combine_fn(i, j):
    ia, ib = i
  2. Chillee revised this gist May 18, 2024. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions assoc_scan.py
    Original file line number Diff line number Diff line change
    @@ -8,8 +8,8 @@ def combine_fn(i, j):
    ja, jb = j
    return ia * ja, ib * ja + jb

    a = torch.randn(1024, 1024)
    b = torch.randn(1024, 1024)
    a = torch.randn(1024, 1024 * 10)
    b = torch.randn(1024, 1024 * 10)

    def baseline(v, u):
    A = []
  3. Chillee created this gist May 18, 2024.
    31 changes: 31 additions & 0 deletions assoc_scan.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,31 @@
    import torch
    import torch.nn as nn
    from torch._higher_order_ops.associative_scan import associative_scan
    from triton.testing import do_bench

    def combine_fn(i, j):
    ia, ib = i
    ja, jb = j
    return ia * ja, ib * ja + jb

    a = torch.randn(1024, 1024)
    b = torch.randn(1024, 1024)

    def baseline(v, u):
    A = []
    A.append(b[:, 0])
    for i in range(1, v.shape[1]):
    A.append(a[:, i] * A[i - 1] + b[:, i])
    return torch.stack(A, dim=1)

    @torch.compile
    def compiled_scan(a, b):
    return associative_scan(combine_fn, (a, b), dim=-1)[1]

    out1 = baseline(a, b)
    out2 = compiled_scan(a, b)
    print((out1 - out2).abs().max())

    print("eager", do_bench(lambda: baseline(a, b)))
    print("compiled", do_bench(lambda: compiled_scan(a, b)))
    print("two cumprods", do_bench(lambda: [torch.cumprod(a, dim=-1), torch.cumprod(b, dim=-1)]))