#!/usr/bin/env python
# coding: utf-8
import pandas as pd
import triton
import triton.language as tl
import torch
from torch import Tensor

try:
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
except ImportError as e:
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e flash-muon/")
    import sys
    sys.exit(1)

assert triton.__version__ >= '3.2.0', "This script requires triton version >= 3.2.0 to run."
assert torch.cuda.is_available(), "Need CUDA device to run!"

current_device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current CUDA device: {device_name}")
print("The script takes a while to run (autotuning the triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbose.")


def get_mmt_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_M': blk_m, 'BLOCK_SIZE_K': blk_k, 'GROUP_SIZE_M': grp_sz},
                          num_stages=n_stages, num_warps=n_warps)
            for blk_m in [32, 64, 128]
            for blk_k in [32, 64]
            for grp_sz in [8]
            for n_stages in [3, 4, 5]
            for n_warps in [2, 4, 8]
            ]


def get_sym_axpbxx_kernel_autotune_config():
    return [triton.Config({'GROUP_SIZE_M': grp_sz},
                          num_stages=n_stages, num_warps=n_warps)
            for grp_sz in [4, 8]
            for n_stages in [1, 2, 3, 4, 5]
            for n_warps in [1, 2, 4, 8]
            ]


def get_sym_aypbxy_kernel_autotune_config():
    return [triton.Config({'BLOCK_SIZE_N': blk_n, 'GROUP_SIZE_M': grp_sz},
                          num_stages=n_stages, num_warps=n_warps)
            for blk_n in [32, 64, 128]
            for grp_sz in [4, 8]
            for n_stages in [1, 2, 3, 4, 5]
            for n_warps in [1, 2, 4, 8]
            ]


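# Kernel 1 (below): y = x @ x.T. Because the result is symmetric, any program
# instance with pid_m > pid_n returns early, so only the upper-triangular blocks
# of y are computed and stored; the correctness check further down masks out the
# lower triangle accordingly.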
@triton.autotune(
    configs=get_mmt_kernel_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(
    x, y,
    M, K,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk

    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)


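# Kernel 2 (below): y = alpha * x + beta * x @ x for a symmetric x of which only
# the upper triangle is populated (the output of mmt_kernel). The k-loop is split
# into three phases so every tile is read from the stored upper triangle
# (m, n, k are block indices):
#   k <  pid_m          : load x[k, m] and x[k, n]; use x[m, k] = x[k, m]^T
#   pid_m <= k <= pid_n : load x[m, k] and x[k, n] directly
#   k >  pid_n          : load x[m, k] and x[n, k]; use x[k, n] = x[n, k]^T
# Like mmt_kernel, it only writes the upper-triangular blocks of the result.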
@triton.autotune(
    configs=get_sym_axpbxx_kernel_autotune_config(),
    key=['M'],
)
@triton.jit
def sym_axpbxx_kernel(
    x, y,
    alpha, beta,
    M,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    calculate y = alpha * x + beta * x @ x, where x is a symmetric matrix and alpha & beta are scalars
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    # we use a & b ptrs to denote different rows of x.
    # a_ptrs_base = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    # b_ptrs_base = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    ktile = 0
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        # print(ktile)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    for k in range(pid_m, pid_n + 1):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        # print(ktile)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    for k in range(pid_n + 1, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + (offs_xn[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        # print(ktile)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        ktile += BLOCK_SIZE_M

    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_tile_ptrs = x + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    a_tile = tl.load(a_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)


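# Kernel 3 (below): z = alpha * y + beta * x @ y, where x is symmetric with only
# its upper triangle stored and y is a general (M, N) matrix. For block rows
# k < pid_m the tile x[m, k] is reconstructed from the stored x[k, m] via a
# transpose; for k >= pid_m it is loaded directly. In fast_ns_iter this kernel is
# launched with alpha=a, beta=1.0 to form the update X <- a * X + B @ X.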
@triton.autotune(
    configs=get_sym_aypbxy_kernel_autotune_config(),
    key=['M', 'N'],
)
@triton.jit
def sym_aypbxy_kernel(
    x, y, z,
    alpha, beta,
    M, N,
    stride_xm, stride_xk,
    stride_ym, stride_yn,
    stride_zm, stride_zn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    calculate z = alpha * y + beta * x @ y, where x is a symmetric matrix, y is a general matrix,
    and alpha & beta are scalars
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_M)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    ktile = 0
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        # print(ktile)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    for k in range(pid_m, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        # print(ktile)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M

    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_tile_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = z + stride_zm * offs_cm[:, None] + stride_zn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    a_tile = tl.load(y_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)


# Single Newton-Schulz iteration built from the three Triton kernels above:
# A = X @ X.T, B = b * A + c * A @ A, X <- a * X + B @ X.
def fast_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    X = X.contiguous()
    M, K = X.shape

    A = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    mmt_kernel[grid](
        X, A,
        M, K,
        X.stride(0), X.stride(1),
        A.stride(0), A.stride(1)
    )
    BLOCK_SIZE_M = mmt_kernel.best_config.kwargs['BLOCK_SIZE_M']

    B = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    sym_axpbxx_kernel[grid](
        A, B,
        b, c,
        M,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )

    N = K  # TODO: rename
    X_ = torch.empty((M, K), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    sym_aypbxy_kernel[grid](
        B, X, X_,
        a, 1.0,
        M, N,
        B.stride(0), B.stride(1),
        X.stride(0), X.stride(1),
        X_.stride(0), X_.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    return A, B, X_


# Reference (dense) implementation of a single Newton-Schulz iteration.
def ref_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X
    return A, B, X


# Correctness check against the reference; A and B are only valid in the upper
# triangle, so the lower triangle is masked out before comparison.
x = torch.randn(512, 512).cuda().half() / 100
A, B, X_ = fast_ns_iter(x)
A_ref, B_ref, X_ref = ref_ns_iter(x)
for (res, ref, name) in [(A, A_ref, 'A'), (B, B_ref, 'B'), (X_, X_ref, 'X')]:
    mask = torch.isclose(res, ref, rtol=1e-2, atol=1e-2)
    if name != 'X':
        size = mask.size(0)
        mask |= torch.tril(torch.ones(size, size, dtype=torch.bool, device=mask.device))
    assert torch.all(mask), f"Results do not match for {name}"
    print(f"Results match for Tensor {name}")


def newtonschulz_base(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def fast_newtonschulz_v2(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        _, _, X = fast_ns_iter(X, a, b, c)
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


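# Benchmark: compare the plain PyTorch iteration (newtonschulz_base), the
# flash-muon CUDA kernel (fast_newtonschulz_v1) and the Triton implementation
# above (fast_newtonschulz_v2) using triton.testing.do_bench on square inputs.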
for N in [1024, 2048, 4096, 8192]:
    x = torch.randn(N, N, device='cuda')
    base = triton.testing.do_bench(lambda: newtonschulz_base(x))
    flash_v1 = triton.testing.do_bench(lambda: fast_newtonschulz_v1(x, steps=5))
    flash_v2 = triton.testing.do_bench(lambda: fast_newtonschulz_v2(x))
    print(f"Dimension: {N:<5} | Torch: {base:>10.3f} ms | Flash V1: {flash_v1:>10.3f} ms| Flash V2: {flash_v2:>10.3f} ms")


"""Example output:
Current CUDA device: NVIDIA A100-PCIE-40GB
The script takes a while to run (autotuning the triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbose.
Results match for Tensor A
Results match for Tensor B
Results match for Tensor X
Dimension: 1024  | Torch:      0.703 ms | Flash V1:      1.355 ms| Flash V2:      1.090 ms
Dimension: 2048  | Torch:      2.272 ms | Flash V1:      1.611 ms| Flash V2:      2.507 ms
Dimension: 4096  | Torch:     10.946 ms | Flash V1:      8.413 ms| Flash V2:      8.147 ms
Dimension: 8192  | Torch:     79.655 ms | Flash V1:     59.229 ms| Flash V2:     62.253 ms
"""
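
# Usage sketch (hypothetical, not executed here): in a Muon-style optimizer step
# the 2D momentum/gradient matrix would be orthogonalized with the fused iteration
# before being applied as the update, e.g.
#
#   m = torch.randn(1024, 4096, device='cuda')   # placeholder momentum matrix
#   update = fast_newtonschulz_v2(m, steps=5)    # approximate orthogonalization
#   # param.data.add_(update, alpha=-lr)         # hypothetical optimizer hook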