#!/usr/bin/env python
# coding: utf-8
import triton
import triton.language as tl
import torch
from torch import Tensor

try:
    from flash_muon import fast_newtonschulz as fast_newtonschulz_v1
except ImportError:
    print("Failed to import fast_newtonschulz from flash_muon. Please ensure the module is installed:")
    print("\tgit clone https://github.com/nil0x9/flash-muon.git && pip install -e flash-muon/")
    import sys
    sys.exit(1)

# Compare parsed version components; a plain string comparison would misorder e.g. '3.10' vs '3.2'.
assert tuple(int(v) for v in triton.__version__.split('.')[:2]) >= (3, 2), \
    "This script requires triton >= 3.2.0 to run."
assert torch.cuda.is_available(), "Need a CUDA device to run!"

current_device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current CUDA device: {device_name}")
print("This script takes a while to run (autotuning for the triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbose.")


def get_mmt_kernel_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': blk_m, 'BLOCK_SIZE_K': blk_k, 'GROUP_SIZE_M': grp_sz},
                      num_stages=n_stages, num_warps=n_warps)
        for blk_m in [32, 64, 128]
        for blk_k in [32, 64]
        for grp_sz in [8]
        for n_stages in [3, 4, 5]
        for n_warps in [2, 4, 8]
    ]


def get_sym_axpbxx_kernel_autotune_config():
    return [
        triton.Config({'GROUP_SIZE_M': grp_sz},
                      num_stages=n_stages, num_warps=n_warps)
        for grp_sz in [4, 8]
        for n_stages in [1, 2, 3, 4, 5]
        for n_warps in [1, 2, 4, 8]
    ]


def get_sym_aypbxy_kernel_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_N': blk_n, 'GROUP_SIZE_M': grp_sz},
                      num_stages=n_stages, num_warps=n_warps)
        for blk_n in [32, 64, 128]
        for grp_sz in [4, 8]
        for n_stages in [1, 2, 3, 4, 5]
        for n_warps in [1, 2, 4, 8]
    ]


@triton.autotune(
    configs=get_mmt_kernel_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(
        x, y,
        M, K,
        stride_xm, stride_xk,
        stride_ym, stride_yn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    Core jitted kernel of matmul_transpose: computes y = x @ x.T.

    Since the result is symmetric, only the upper-triangular blocks of y are
    computed (programs with pid_m > pid_n exit early). The code is a simple
    adaptation of the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)  # the output is M x M
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
        return
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # a_ptrs and b_ptrs index two different row blocks of x.
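    # Tile a covers rows pid_m*BLOCK_SIZE_M..., tile b covers rows pid_n*BLOCK_SIZE_M...;
    # b is transposed inside the dot below, so the output tile is
    # x[rows_m, :] @ x[rows_n, :].T.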
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # Use dtype.element_ty to accommodate different input datatypes, as with C++ templates:
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)


@triton.autotune(
    configs=get_sym_axpbxx_kernel_autotune_config(),
    key=['M'],
)
@triton.jit
def sym_axpbxx_kernel(
        x, y, alpha, beta,
        M,
        stride_xm, stride_xk,
        stride_ym, stride_yn,
        BLOCK_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    Calculate y = alpha * x + beta * x @ x, where x is a symmetric matrix
    (only its upper-triangular blocks need to be valid) and alpha & beta are
    scalars. As in mmt_kernel, only the upper-triangular blocks of y are written.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    if pid_m > pid_n:
        return
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_M)
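    # x is symmetric and only its upper-triangular blocks were written by
    # mmt_kernel, so the k loop is split into three phases that touch only
    # blocks on or above the diagonal:
    #   k <  pid_m         : the a-tile (pid_m, k) is read as the transpose of (k, pid_m)
    #   pid_m <= k <= pid_n: both tiles already lie in the upper triangle
    #   k >  pid_n         : the b-tile (k, pid_n) is read as the transpose of (pid_n, k)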
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    ktile = 0
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        # offs_k indexes rows in both tiles here, so the tail mask runs along axis 0
        a = tl.load(a_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    for k in range(pid_m, pid_n + 1):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xn[None, :] * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    for k in range(pid_n + 1, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = x + (offs_xn[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        ktile += BLOCK_SIZE_M
    # Use dtype.element_ty to accommodate different input datatypes, as with C++ templates:
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty, fp_downcast_rounding='rtne')
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_tile_ptrs = x + stride_xm * offs_cm[:, None] + stride_xk * offs_cn[None, :]
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    a_tile = tl.load(a_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * a_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)


@triton.autotune(
    configs=get_sym_aypbxy_kernel_autotune_config(),
    key=['M', 'N'],
)
@triton.jit
def sym_aypbxy_kernel(
        x, y, z, alpha, beta,
        M, N,
        stride_xm, stride_xk,
        stride_ym, stride_yn,
        stride_zm, stride_zn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr
):
    """
    Calculate z = alpha * y + beta * x @ y, where x is a symmetric matrix
    (upper-triangular blocks valid), y is a regular matrix, and alpha & beta
    are scalars.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_M)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    ktile = 0
    # Phase 1 (k < pid_m): k-tiles that sit left of x's diagonal block.
    for k in range(0, pid_m):
        a_ptrs = x + ((offs_k[:, None] + ktile) * stride_xm + offs_xm[None, :] * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        a = tl.load(a_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
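        # x is symmetric: the tile at (pid_m, k) equals the transpose of the
        # (k, pid_m) tile just loaded, so a is permuted before the dot.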
        accumulator = tl.dot(tl.permute(a, (1, 0)), b, accumulator)
        ktile += BLOCK_SIZE_M
    # Phase 2 (k >= pid_m): x-tiles lie on or right of the diagonal block and load directly.
    for k in range(pid_m, tl.cdiv(M, BLOCK_SIZE_M)):
        a_ptrs = x + (offs_xm[:, None] * stride_xm + (offs_k[None, :] + ktile) * stride_xk)
        b_ptrs = y + ((offs_k[:, None] + ktile) * stride_ym + offs_yn[None, :] * stride_yn)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - ktile, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < M - ktile, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        ktile += BLOCK_SIZE_M
    # Use dtype.element_ty to accommodate different input datatypes, as with C++ templates:
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_tile_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_ptrs = z + stride_zm * offs_cm[:, None] + stride_zn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    y_tile = tl.load(y_tile_ptrs, mask=c_mask, other=0.0)
    c = alpha * y_tile + beta * c
    tl.store(c_ptrs, c, mask=c_mask)


def fast_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    """One Newton-Schulz iteration, A = X @ X.T; B = b*A + c*A@A; X' = a*X + B@X, using the kernels above."""
    X = X.contiguous()
    M, K = X.shape
    A = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    mmt_kernel[grid](
        X, A,
        M, K,
        X.stride(0), X.stride(1),
        A.stride(0), A.stride(1)
    )
    # Reuse the tuned block size: the downstream kernels must tile with the same
    # BLOCK_SIZE_M so the blocks they read line up with the upper-triangular
    # blocks mmt_kernel actually wrote.
    BLOCK_SIZE_M = mmt_kernel.best_config.kwargs['BLOCK_SIZE_M']
    B = torch.empty((M, M), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(M, META['BLOCK_SIZE_M']), )
    sym_axpbxx_kernel[grid](
        A, B, b, c,
        M,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    N = K  # TODO: rename
    X_ = torch.empty((M, K), device=X.device, dtype=X.dtype)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    sym_aypbxy_kernel[grid](
        B, X, X_, a, 1.0,
        M, N,
        B.stride(0), B.stride(1),
        X.stride(0), X.stride(1),
        X_.stride(0), X_.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M
    )
    return A, B, X_


def ref_ns_iter(X, a=3.4445, b=-4.7750, c=2.0315):
    """Plain-PyTorch reference for one Newton-Schulz iteration."""
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X
    return A, B, X


x = torch.randn(512, 512).cuda().half() / 100
A, B, X_ = fast_ns_iter(x)
A_ref, B_ref, X_ref = ref_ns_iter(x)
for (res, ref, name) in [(A, A_ref, 'A'), (B, B_ref, 'B'), (X_, X_ref, 'X')]:
    mask = torch.isclose(res, ref, rtol=1e-2, atol=1e-2)
    if name != 'X':
        # A and B are only valid in their upper triangles; ignore the lower one.
        size = mask.size(0)
        mask |= torch.tril(torch.ones(size, size, dtype=torch.bool, device=mask.device))
    assert torch.all(mask), f"Results do not match for {name}"
    print(f"Results match for Tensor {name}")


def newtonschulz_base(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def fast_newtonschulz_v2(G: Tensor, steps: int = 5) -> Tensor:
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations with the fused triton kernels
    for _ in range(steps):
        _, _, X = fast_ns_iter(X, a, b, c)
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
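# A minimal sketch (not part of the benchmark, and not flash-muon's API) of where a
# Newton-Schulz orthogonalizer like fast_newtonschulz_v2 would sit in a Muon-style
# update: accumulate momentum, orthogonalize it, then step along that direction.
# `muon_step` and its parameters are hypothetical names for illustration only.
def muon_step(param: Tensor, grad: Tensor, momentum_buf: Tensor,
              lr: float = 0.02, momentum: float = 0.95) -> None:
    momentum_buf.mul_(momentum).add_(grad)        # classic momentum accumulation
    update = fast_newtonschulz_v2(momentum_buf)   # orthogonalize the 2D update
    param.add_(update.to(param.dtype), alpha=-lr) # descend along the orthogonalized direction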
for N in [1024, 2048, 4096, 8192]:
    x = torch.randn(N, N, device='cuda')
    base = triton.testing.do_bench(lambda: newtonschulz_base(x))
    flash_v1 = triton.testing.do_bench(lambda: fast_newtonschulz_v1(x, steps=5))
    flash_v2 = triton.testing.do_bench(lambda: fast_newtonschulz_v2(x))
    print(f"Dimension: {N:<5} | Torch: {base:>10.3f} ms | Flash V1: {flash_v1:>10.3f} ms | Flash V2: {flash_v2:>10.3f} ms")

"""Example output:

Current CUDA device: NVIDIA A100-PCIE-40GB
This script takes a while to run (autotuning for the triton kernels). Set TRITON_PRINT_AUTOTUNING=1 to make autotuning verbose.
Results match for Tensor A
Results match for Tensor B
Results match for Tensor X
Dimension: 1024  | Torch:      0.703 ms | Flash V1:      1.355 ms | Flash V2:      1.090 ms
Dimension: 2048  | Torch:      2.272 ms | Flash V1:      1.611 ms | Flash V2:      2.507 ms
Dimension: 4096  | Torch:     10.946 ms | Flash V1:      8.413 ms | Flash V2:      8.147 ms
Dimension: 8192  | Torch:     79.655 ms | Flash V1:     59.229 ms | Flash V2:     62.253 ms
"""
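# Optional sanity check (an addition, not part of the original benchmark): after the
# NS steps the output of a square input should be roughly orthogonal, i.e. X @ X.T
# close to I. The quintic coefficients (3.4445, -4.7750, 2.0315) trade exactness for
# speed, so singular values only approximately reach 1; the printed error is
# indicative rather than a pass/fail threshold.
x = torch.randn(1024, 1024, device='cuda')
X = fast_newtonschulz_v2(x)
eye = torch.eye(1024, device='cuda', dtype=X.dtype)
err = (X @ X.mT - eye).abs().max().item()
print(f"max |X @ X.T - I| after 5 NS steps: {err:.3f}")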