Skip to content

Instantly share code, notes, and snippets.

@fvarno
Created February 12, 2024 16:25
Show Gist options
  • Select an option

  • Save fvarno/d90655aa6087c57ccfc49f16c1f08bee to your computer and use it in GitHub Desktop.

Select an option

Save fvarno/d90655aa6087c57ccfc49f16c1f08bee to your computer and use it in GitHub Desktop.

Revisions

  1. fvarno created this gist Feb 12, 2024.
    13 changes: 13 additions & 0 deletions gemm_tiling.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,13 @@
    import torch
    # multiple a MxN matrix with a NxK matrix
    M, N, K = 20,10,30
    A = torch.randn(M,N)
    B = torch.randn(N,K)
    untiled_res = torch.matmul(A, B)
    tile_size=5

    A_ = A.reshape(M, 1, N//tile_size, 1, tile_size)
    B_ = B.reshape(N//tile_size, tile_size, K//tile_size, tile_size).permute(2,0,1,3)
    tiled_res = torch.matmul(A_, B_).sum(dim=2)
    print(torch.allclose(untiled_res, tiled_res.reshape(M,K)))
    print((untiled_res - tiled_res.reshape(M,K)).abs().max())