Last active
September 19, 2025 09:51
-
-
Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.
Revisions
-
Chillee revised this gist
Jul 29, 2025 . 1 changed file with 2 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -110,7 +110,7 @@ def g(): tiler_mn = (8, 8) tv = ( ((2, 2, 2), (2, 2, 2)), # thr_shape / val_shape ((1, 16, 4), (8, 2, 32)), # thr_stride / val_stride ) visualize_tv_layout(tiler_mn, tv) -
Chillee revised this gist
Jul 29, 2025 . 1 changed file with 26 additions and 20 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,17 +1,15 @@ import math import cutlass.cute as cute import cutlass def visualize_tv_layout( tiler_mn: tuple[int, int], tv_layout, # (((thr_shape),(val_shape)), # ((thr_stride),(val_stride))) *, font_size: int = 10, cell_px: int = 70, grid_lw: float = 1.5, color_fn=None, # optional (tid,vid) -> colour ): """Draw a T/V checkerboard for an arbitrary TV layout.""" @@ -42,14 +40,20 @@ def visualize_tv_layout( # 2) Query CuTe for every (tid, vid) → (m,n) # ----------------------------------------------------------------- @cute.jit def g(): tv_layout = cute.make_layout(shape, stride=stride) tid_vals = [] for tid in cutlass.range_constexpr(n_thr): vid_vals = [] for vid in cutlass.range_constexpr(n_val): vid_vals.append(tv_layout((tid, vid))) tid_vals.append(vid_vals) return tid_vals vals = g() for tid in range(n_thr): for vid in range(n_val): pos = vals[tid][vid] n = pos // M m = pos % M if filled[m, n]: @@ -89,22 +93,24 @@ def g(): fontsize=font_size, weight="bold" ) ax.set_xticks(np.arange(N + 1) - 0.5) ax.set_yticks(np.arange(M + 1) - 0.5) ax.set_xticklabels([str(i) for i in range(N + 1)]) # Show x tick labels ax.set_yticklabels([str(i) for i in range(M + 1)]) # Show y tick labels ax.tick_params(axis='both', which='both', length=6, width=1) # Make ticks more visible ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False) # Show ticks and labels on top ax.tick_params(axis='y', which='both', left=True, right=False) # Show ticks on left ax.grid(which="major", color="black", linewidth=grid_lw) ax.set_xlim(-.5, N -.5); ax.set_ylim(M -.5, -.5) ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12) plt.tight_layout() plt.savefig("tv_layout.svg") tiler_mn = (8, 8) tv = ( ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride ) visualize_tv_layout(tiler_mn, tv) -
Chillee created this gist
Jul 28, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,110 @@ import math import cutlass.cute as cute # ──────────────────────────────────────────────────────────────────── # core helper # ──────────────────────────────────────────────────────────────────── def visualize_tv_layout( tiler_mn: tuple[int, int], tv_layout, # (((thr_shape),(val_shape)), # ((thr_stride),(val_stride))) *, font_size: int = 32, cell_px: int = 200, grid_lw: float = 2.5, color_fn=None, # optional (tid,vid) -> colour ): """Draw a T/V checkerboard for an arbitrary TV layout.""" import numpy as np import matplotlib.pyplot as plt import matplotlib.colors as mcolors # ----------------------------------------------------------------- # 1) Build a real CuTe layout from the tuple the user passed # ----------------------------------------------------------------- shape, stride = tv_layout if isinstance(shape[0], int): n_thr = shape[0] else: n_thr = math.prod(shape[0]) if isinstance(shape[1], int): n_val = shape[1] else: n_val = math.prod(shape[1]) M, N = tiler_mn thr_ids = np.full((M, N), -1, dtype=int) val_ids = np.full((M, N), -1, dtype=int) filled = np.zeros((M, N), dtype=bool) # ----------------------------------------------------------------- # 2) Query CuTe for every (tid, vid) → (m,n) # ----------------------------------------------------------------- for tid in range(n_thr): for vid in range(n_val): @cute.jit def g(): tv_layout = cute.make_layout(shape, stride=stride) out = tv_layout((tid, vid)) return out pos = g() n = pos // M m = pos % M if filled[m, n]: continue thr_ids[m, n] = tid val_ids[m, n] = vid filled[m, n] = True # ----------------------------------------------------------------- # 3) Colours (default: pastel per-thread) # ----------------------------------------------------------------- if color_fn is None: pastel = plt.cm.Set3.colors cmap = (pastel * ((n_thr // 12) + 1))[:n_thr] color_fn = lambda t, v: cmap[t % len(cmap)] bg_rgb = np.zeros((M, N, 3)) for m in range(M): for n in range(N): tid = thr_ids[m, n] if tid >= 0: bg_rgb[m, n] = mcolors.to_rgb(color_fn(tid, val_ids[m, n])) # ----------------------------------------------------------------- # 4) Draw # ----------------------------------------------------------------- fig_w, fig_h = N * cell_px / 100, M * cell_px / 100 fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100) ax.imshow(bg_rgb, interpolation="none") for m in range(M): for n in range(N): if thr_ids[m, n] >= 0: ax.text( n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}", ha="center", va="center", fontsize=font_size, weight="bold" ) ax.set_xticks(np.arange(N + 1) - .5, minor=True) ax.set_yticks(np.arange(M + 1) - .5, minor=True) ax.grid(which="minor", color="black", linewidth=grid_lw) ax.tick_params(which="minor", bottom=False, left=False) ax.set_xticks([]); ax.set_yticks([]) ax.set_xlim(-.5, N -.5); ax.set_ylim(M -.5, -.5) ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12) plt.tight_layout() plt.show() plt.savefig("tv_layout.svg") tiler_mn = (8, 8) tv = ( ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride ) visualize_tv_layout(tiler_mn, tv)