@fvarno
Created February 12, 2024 16:25
GeMM with tiling
import torch
# multiply an MxN matrix by an NxK matrix
M, N, K = 20,10,30
A = torch.randn(M,N)
B = torch.randn(N,K)
untiled_res = torch.matmul(A, B)
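tile_size = 5
# tile_size must divide both N and K for the reshapes below
assert N % tile_size == 0 and K % tile_size == 0
# A_: (M, 1, N/t, 1, t), each row of A split into N/t row-vector tiles
A_ = A.reshape(M, 1, N//tile_size, 1, tile_size)
# B_: (K/t, N/t, t, t), a grid of t x t tiles indexed by (column tile, row tile)
B_ = B.reshape(N//tile_size, tile_size, K//tile_size, tile_size).permute(2,0,1,3)
# batched matmul broadcasts to (M, K/t, N/t, 1, t); summing dim 2 accumulates over the N tiles
tiled_res = torch.matmul(A_, B_).sum(dim=2)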
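# float32 summation order differs between the two paths, so allow a small absolute tolerance
print(torch.allclose(untiled_res, tiled_res.reshape(M,K), atol=1e-5))
# largest absolute difference between the two results
print((untiled_res - tiled_res.reshape(M,K)).abs().max())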
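The reshape/permute version above hides the tiling inside a broadcasted batched matmul. For comparison, here is a minimal sketch of the same idea written as explicit loops over tiles, which is closer to how tiled GEMM is usually described. The helper name `tiled_matmul_loops` is hypothetical (it is not part of the gist), and the sketch assumes `tile_size` divides all three dimensions.

```python
import torch

def tiled_matmul_loops(A, B, tile_size):
    """Hypothetical helper: multiply A (MxN) by B (NxK) one tile at a time."""
    M, N = A.shape
    N2, K = B.shape
    assert N == N2, "inner dimensions must match"
    # Assumption of this sketch: tile_size divides M, N and K exactly
    assert M % tile_size == 0 and N % tile_size == 0 and K % tile_size == 0
    t = tile_size
    C = torch.zeros(M, K, dtype=A.dtype)
    for i in range(0, M, t):          # row offset of the output tile
        for j in range(0, K, t):      # column offset of the output tile
            for k in range(0, N, t):  # walk the shared dimension tile by tile
                # accumulate the partial product of one A tile and one B tile
                C[i:i+t, j:j+t] += A[i:i+t, k:k+t] @ B[k:k+t, j:j+t]
    return C

torch.manual_seed(0)
A = torch.randn(20, 10)
B = torch.randn(10, 30)
C = tiled_matmul_loops(A, B, tile_size=5)
# largest absolute difference from the untiled product
print((C - A @ B).abs().max())
```

The innermost loop plays the same role as `.sum(dim=2)` in the vectorized version: both accumulate partial products over the tiles of the shared dimension N.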