Created
February 12, 2024 16:25
-
-
Save fvarno/d90655aa6087c57ccfc49f16c1f08bee to your computer and use it in GitHub Desktop.
Revisions
-
fvarno created this gist
Feb 12, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,13 @@ import torch # multiple a MxN matrix with a NxK matrix M, N, K = 20,10,30 A = torch.randn(M,N) B = torch.randn(N,K) untiled_res = torch.matmul(A, B) tile_size=5 A_ = A.reshape(M, 1, N//tile_size, 1, tile_size) B_ = B.reshape(N//tile_size, tile_size, K//tile_size, tile_size).permute(2,0,1,3) tiled_res = torch.matmul(A_, B_).sum(dim=2) print(torch.allclose(untiled_res, tiled_res.reshape(M,K))) print((untiled_res - tiled_res.reshape(M,K)).abs().max())