Skip to content

Instantly share code, notes, and snippets.

@nil0x9
Last active May 20, 2025 08:37
Show Gist options
  • Save nil0x9/16956ab4b66fcfbf9f81a174fef6bf71 to your computer and use it in GitHub Desktop.
Save nil0x9/16956ab4b66fcfbf9f81a174fef6bf71 to your computer and use it in GitHub Desktop.

Revisions

  1. nil0x9 revised this gist May 20, 2025. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion flash_muon_partial_materialization.py
    Original file line number Diff line number Diff line change
    @@ -9,7 +9,7 @@
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
    except ImportError as e:
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e ./")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e flash-muon/")
    import sys
    sys.exit(1)

  2. nil0x9 created this gist May 8, 2025.
    388 changes: 388 additions & 0 deletions flash_muon_partial_materialization.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,388 @@
    #!/usr/bin/env python
    # coding: utf-8
    import pandas as pd
    import triton
    import triton.language as tl
    import torch
    from torch import Tensor
    try:
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
    except ImportError as e:
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e ./")
    import sys
    sys.exit(1)

    assert triton.__version__ >= '3.2.0', "This scripts requires triton version >= 3.2.0 to run."

    assert torch.cuda.is_available(), "Need CUDA device to run!"
    current_device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(current_device)
    print(f"Current CUDA device: {device_name}")

    print("The scripts takes abit long to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbal.")

    def get_mmt_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_M': blk_m, 'BLOCK_SIZE_K': blk_k, 'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
    for blk_m in [32, 64, 128]
    for blk_k in [32, 64]
    for grp_sz in [8]
    for n_stages in [3, 4, 5]
    for n_warps in [2, 4, 8]
    ]

    def get_sym_axpbxx_kernel_autotune_config():
    return [triton.Config({'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
    for grp_sz in [4, 8]
    for n_stages in [1, 2, 3, 4, 5]
    for n_warps in [1, 2, 4, 8]
    ]

    def get_sym_aypbxy_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_N': blk_n, 'GROUP_SIZE_M': grp_sz}, num_stages=n_stages, num_warps=n_warps)
    for blk_n in [32, 64, 128]
    for grp_sz in [4, 8]
    for n_stages in [1, 2, 3, 4, 5]
    for n_warps in [1, 2, 4, 8]
    ]


    @triton.autotune(
    configs=get_mmt_kernel_autotune_config(),
    key=['M', 'K'],
    )
    @triton.jit
    def mmt_kernel(
    x, y,
    M, K,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
    ):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
    return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
    a_ptrs += BLOCK_SIZE_K * stride_xk
    b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accomodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)


    @triton.autotune(
    configs=get_sym_axpbxx_kernel_autotune_config(),
    key=['M'],
    )
    @triton.jit
    def sym_axpbxx_kernel(
    x, y,
    alpha, beta,
    M,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
    ):
    """
    calculate y = alpha * x + beta * x @ x, where x is symetric matrix, alpha & beta are scalars
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
    return
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    # we use a & b ptrs to denote different rows of x.
    # a_ptrs_base = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    # b_ptrs_base = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    ktile = 0
    for k in range(0, pid_m):
    a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
    b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
    # print(ktile)
    a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
    ktile += BLOCK_SIZE_M

    for k in range(pid_m, pid_n+1):
    a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
    b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
    # print(ktile)
    a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    accumulator = tl.dot(a, b, accumulator)
    ktile += BLOCK_SIZE_M

    for k in range(pid_n+1, tl.cdiv(M, BLOCK_SIZE_M)):
    a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
    # print(ktile)
    a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
    accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
    ktile += BLOCK_SIZE_M

    # use dtype.element_ty to accomodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_tile_ptrs = x + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    a_tile = tl.load(a_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)

    @triton.autotune(
    configs=get_sym_aypbxy_kernel_autotune_config(),
    key=['M', 'N'],
    )
    @triton.jit
    def sym_aypbxy_kernel(
    x, y,
    z,
    alpha, beta,
    M,
    N,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    stride_zm, stride_zn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
    ):
    """
    calculate y = alpha * y + beta * x @ y, where x is a symetric matrix, y is a normal matrix, alpha & beta are scalars
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_M)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    ktile = 0
    for k in range(0, pid_m):
    a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
    b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
    # print(ktile)
    a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
    accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
    ktile += BLOCK_SIZE_M

    for k in range(pid_m, tl.cdiv(M, BLOCK_SIZE_M)):
    a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
    b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
    # print(ktile)
    a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
    accumulator = tl.dot(a, b, accumulator)
    ktile += BLOCK_SIZE_M

    # use dtype.element_ty to accomodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    y_tile_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = z + stride_zm * offs_cm[:, None] + stride_zn * offs_cn[None, :]

    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    a_tile = tl.load(y_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)


    def fast_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    X = X.contiguous()
    M, K = X.shape
    A = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    mmt_kernel[grid](
    X,
    A,
    M,
    K,
    X.stride(0),
    X.stride(1),
    A.stride(0),
    A.stride(1)
    )
    BLOCK_SIZE_M = mmt_kernel.best_config.kwargs['BLOCK_SIZE_M']

    B = torch.empty((M, M), device=X.device, dtype=X.dtype)

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    sym_axpbxx_kernel[grid](
    A,
    B,
    b,
    c,
    M,
    A.stride(0),
    A.stride(1),
    B.stride(0),
    B.stride(1),
    BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    N = K # TODO: rename
    X_ = torch.empty((M, K), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    sym_aypbxy_kernel[grid](
    B,
    X,
    X_,
    a,
    1.0,
    M,
    N,
    B.stride(0),
    B.stride(1),
    X.stride(0),
    X.stride(1),
    X_.stride(0),
    X_.stride(1),
    BLOCK_SIZE_M=BLOCK_SIZE_M
    )

    return A, B, X_

    def ref_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X
    return A, B, X

    x = torch.randn(512, 512).cuda().half()/100

    A, B, X_ = fast_ns_iter(x)

    A_ref, B_ref, X_ref = ref_ns_iter(x)

    for (res, ref, name) in [(A, A_ref, 'A'), (B, B_ref, 'B'), (X_, X_ref, 'X')]:
    mask = torch.isclose(res, ref, rtol=1e-2, atol=1e-2)
    if name != 'X':
    size = mask.size(0)
    mask |= torch.tril(torch.ones(size, size, dtype=torch.bool, device=mask.device))
    assert torch.all(mask), f"Results not match for {name}"
    print(f"Results match for Tensor {name}")


    def newtonschulz_base(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
    X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X

    if G.size(-2) > G.size(-1):
    X = X.mT
    return X



    def fast_newtonschulz_v2(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
    X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
    _, _, X = fast_ns_iter(X, a, b, c)

    if G.size(-2) > G.size(-1):
    X = X.mT
    return X


    for N in [1024, 2048, 4096, 8192]:
    x = torch.randn(N, N, device='cuda')
    base = triton.testing.do_bench(lambda: newtonschulz_base(x))
    flash_v1 = triton.testing.do_bench(lambda: fast_newtonschulz_v1(x, steps=5))
    flash_v2 = triton.testing.do_bench(lambda: fast_newtonschulz_v2(x))
    print(f"Dimension: {N:<5} | Torch: {base:>10.3f} ms | Flash V1: {flash_v1:>10.3f} ms| Flash V2: {flash_v2:>10.3f} ms")


    """Example output:
    Current CUDA device: NVIDIA A100-PCIE-40GB
    The scripts takes abit long to run (autotuning for triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbal.
    Results match for Tensor A
    Results match for Tensor B
    Results match for Tensor X
    Dimension: 1024 | Torch: 0.703 ms | Flash V1: 1.355 ms| Flash V2: 1.090 ms
    Dimension: 2048 | Torch: 2.272 ms | Flash V1: 1.611 ms| Flash V2: 2.507 ms
    Dimension: 4096 | Torch: 10.946 ms | Flash V1: 8.413 ms| Flash V2: 8.147 ms
    Dimension: 8192 | Torch: 79.655 ms | Flash V1: 59.229 ms| Flash V2: 62.253 ms
    """