Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active September 19, 2025 09:51
Show Gist options
  • Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.
Save Chillee/e2b07157caeade8c6b0bdf463d10f833 to your computer and use it in GitHub Desktop.

Revisions

  1. Chillee revised this gist Jul 29, 2025. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions tv_layout_viz.py
    Original file line number Diff line number Diff line change
    @@ -110,7 +110,7 @@ def g():

    tiler_mn = (8, 8)
    tv = (
    ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride
    ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride
    ((2, 2, 2), (2, 2, 2)), # thr_shape / val_shape
    ((1, 16, 4), (8, 2, 32)), # thr_stride / val_stride
    )
    visualize_tv_layout(tiler_mn, tv)
  2. Chillee revised this gist Jul 29, 2025. 1 changed file with 26 additions and 20 deletions.
    46 changes: 26 additions & 20 deletions tv_layout_viz.py
    Original file line number Diff line number Diff line change
    @@ -1,17 +1,15 @@
    import math
    import cutlass.cute as cute
    import cutlass

    # ────────────────────────────────────────────────────────────────────
    # core helper
    # ────────────────────────────────────────────────────────────────────
    def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout, # (((thr_shape),(val_shape)),
    # ((thr_stride),(val_stride)))
    *,
    font_size: int = 32,
    cell_px: int = 200,
    grid_lw: float = 2.5,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    color_fn=None, # optional (tid,vid) -> colour
    ):
    """Draw a T/V checkerboard for an arbitrary TV layout."""
    @@ -42,14 +40,20 @@ def visualize_tv_layout(
    # 2) Query CuTe for every (tid, vid) → (m,n)
    # -----------------------------------------------------------------

    @cute.jit
    def g():
    tv_layout = cute.make_layout(shape, stride=stride)
    tid_vals = []
    for tid in cutlass.range_constexpr(n_thr):
    vid_vals = []
    for vid in cutlass.range_constexpr(n_val):
    vid_vals.append(tv_layout((tid, vid)))
    tid_vals.append(vid_vals)
    return tid_vals
    vals = g()
    for tid in range(n_thr):
    for vid in range(n_val):
    @cute.jit
    def g():
    tv_layout = cute.make_layout(shape, stride=stride)
    out = tv_layout((tid, vid))
    return out
    pos = g()
    pos = vals[tid][vid]
    n = pos // M
    m = pos % M
    if filled[m, n]:
    @@ -89,22 +93,24 @@ def g():
    fontsize=font_size, weight="bold"
    )

    ax.set_xticks(np.arange(N + 1) - .5, minor=True)
    ax.set_yticks(np.arange(M + 1) - .5, minor=True)
    ax.grid(which="minor", color="black", linewidth=grid_lw)
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)]) # Show x tick labels
    ax.set_yticklabels([str(i) for i in range(M + 1)]) # Show y tick labels
    ax.tick_params(axis='both', which='both', length=6, width=1) # Make ticks more visible
    ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False) # Show ticks and labels on top
    ax.tick_params(axis='y', which='both', left=True, right=False) # Show ticks on left
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-.5, N -.5); ax.set_ylim(M -.5, -.5)

    ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12)
    plt.tight_layout()
    plt.show()
    plt.savefig("tv_layout.svg")


    tiler_mn = (8, 8)
    tv = (
    ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride
    ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride
    ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride
    ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride
    )
    visualize_tv_layout(tiler_mn, tv)
  3. Chillee created this gist Jul 28, 2025.
    110 changes: 110 additions & 0 deletions tv_layout_viz.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,110 @@
    import math
    import cutlass.cute as cute

    # ────────────────────────────────────────────────────────────────────
    # core helper
    # ────────────────────────────────────────────────────────────────────
    def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout, # (((thr_shape),(val_shape)),
    # ((thr_stride),(val_stride)))
    *,
    font_size: int = 32,
    cell_px: int = 200,
    grid_lw: float = 2.5,
    color_fn=None, # optional (tid,vid) -> colour
    ):
    """Draw a T/V checkerboard for an arbitrary TV layout."""
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors

    # -----------------------------------------------------------------
    # 1) Build a real CuTe layout from the tuple the user passed
    # -----------------------------------------------------------------
    shape, stride = tv_layout

    if isinstance(shape[0], int):
    n_thr = shape[0]
    else:
    n_thr = math.prod(shape[0])
    if isinstance(shape[1], int):
    n_val = shape[1]
    else:
    n_val = math.prod(shape[1])
    M, N = tiler_mn

    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)
    filled = np.zeros((M, N), dtype=bool)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) → (m,n)
    # -----------------------------------------------------------------

    for tid in range(n_thr):
    for vid in range(n_val):
    @cute.jit
    def g():
    tv_layout = cute.make_layout(shape, stride=stride)
    out = tv_layout((tid, vid))
    return out
    pos = g()
    n = pos // M
    m = pos % M
    if filled[m, n]:
    continue
    thr_ids[m, n] = tid
    val_ids[m, n] = vid
    filled[m, n] = True

    # -----------------------------------------------------------------
    # 3) Colours (default: pastel per-thread)
    # -----------------------------------------------------------------
    if color_fn is None:
    pastel = plt.cm.Set3.colors
    cmap = (pastel * ((n_thr // 12) + 1))[:n_thr]
    color_fn = lambda t, v: cmap[t % len(cmap)]

    bg_rgb = np.zeros((M, N, 3))
    for m in range(M):
    for n in range(N):
    tid = thr_ids[m, n]
    if tid >= 0:
    bg_rgb[m, n] = mcolors.to_rgb(color_fn(tid, val_ids[m, n]))

    # -----------------------------------------------------------------
    # 4) Draw
    # -----------------------------------------------------------------
    fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
    ax.imshow(bg_rgb, interpolation="none")

    for m in range(M):
    for n in range(N):
    if thr_ids[m, n] >= 0:
    ax.text(
    n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
    ha="center", va="center",
    fontsize=font_size, weight="bold"
    )

    ax.set_xticks(np.arange(N + 1) - .5, minor=True)
    ax.set_yticks(np.arange(M + 1) - .5, minor=True)
    ax.grid(which="minor", color="black", linewidth=grid_lw)
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlim(-.5, N -.5); ax.set_ylim(M -.5, -.5)

    ax.set_title(f"tv_layout = {tv_layout}", fontsize=font_size + 2, pad=12)
    plt.tight_layout()
    plt.show()
    plt.savefig("tv_layout.svg")


    tiler_mn = (8, 8)
    tv = (
    ((2, 2, 2), (2, 2, 2)), # thr_shape / thr_stride
    ((1, 16, 4), (8, 2, 32)), # val_shape / val_stride
    )
    visualize_tv_layout(tiler_mn, tv)