Created
February 12, 2024 16:25
-
-
Save fvarno/d90655aa6087c57ccfc49f16c1f08bee to your computer and use it in GitHub Desktop.
GEMM with tiling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch


def tiled_matmul(a, b, tile_size=5, *, tile_k=None):
    """Multiply an (M, N) matrix by an (N, K) matrix using tiled broadcasting.

    The contraction dimension N is split into tiles of ``tile_size``; the
    output-column dimension K is split into tiles of ``tile_k`` (defaults to
    ``tile_size``). Each row of ``a`` is multiplied tile-by-tile against the
    matching tiles of ``b`` in a single batched ``torch.matmul``, and the
    per-tile partial products are then summed over the contraction tiles.

    Args:
        a: 2-D tensor of shape (M, N).
        b: 2-D tensor of shape (N, K).
        tile_size: tile length along the contraction dimension N; must be a
            positive divisor of N.
        tile_k: tile length along the output-column dimension K; must be a
            positive divisor of K. ``None`` means "use ``tile_size``".

    Returns:
        Tensor of shape (M, K) equal (up to floating-point summation order)
        to ``torch.matmul(a, b)``.

    Raises:
        ValueError: if the inputs are not 2-D, their inner dimensions do not
            match, or a tile size is non-positive or does not divide its
            dimension.
    """
    if tile_k is None:
        tile_k = tile_size
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"expected 2-D inputs, got shapes {tuple(a.shape)} and {tuple(b.shape)}"
        )
    m, n = a.shape
    n_b, k = b.shape
    if n != n_b:
        raise ValueError(f"inner dimensions differ: {n} vs {n_b}")
    if tile_size <= 0 or tile_k <= 0:
        raise ValueError("tile sizes must be positive")
    if n % tile_size or k % tile_k:
        raise ValueError(
            f"tile sizes ({tile_size}, {tile_k}) must divide (N, K) = ({n}, {k})"
        )

    # (M, 1, N/tile_size, 1, tile_size): each row of A cut into contraction
    # tiles; the singleton at dim 1 broadcasts over B's column-tile axis, and
    # the singleton at dim 3 makes each tile a (1, tile_size) row matrix.
    a_tiles = a.reshape(m, 1, n // tile_size, 1, tile_size)
    # (K/tile_k, N/tile_size, tile_size, tile_k): B cut into blocks, with the
    # column-tile axis moved first so it lines up with A's dim 1.
    b_tiles = b.reshape(n // tile_size, tile_size, k // tile_k, tile_k).permute(
        2, 0, 1, 3
    )
    # Batched (1, tile_size) @ (tile_size, tile_k) products, broadcast to
    # shape (M, K/tile_k, N/tile_size, 1, tile_k).
    partial = torch.matmul(a_tiles, b_tiles)
    # Summing over the contraction-tile axis completes the dot products;
    # (M, K/tile_k, 1, tile_k) then flattens back to (M, K) in column order.
    return partial.sum(dim=2).reshape(m, k)


if __name__ == "__main__":
    # Multiply an MxN matrix by an NxK matrix and compare against torch.matmul.
    torch.manual_seed(0)  # reproducible demo
    M, N, K = 20, 10, 30
    tile_size = 5
    A = torch.randn(M, N)
    B = torch.randn(N, K)
    untiled_res = torch.matmul(A, B)
    tiled_res = tiled_matmul(A, B, tile_size)
    # The two paths accumulate float32 products in different orders, so
    # near-zero entries can differ by a few ULPs of the partial sums; the
    # default atol=1e-8 is too tight and would report spurious mismatches.
    print(torch.allclose(untiled_res, tiled_res, rtol=1e-5, atol=1e-5))
    print((untiled_res - tiled_res).abs().max())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment