Skip to content

Instantly share code, notes, and snippets.

@jeromeku
Forked from fengxie/vector_as_kernel_arg.py
Created September 14, 2025 17:12
Show Gist options
  • Save jeromeku/2a2373b97f4b505e535dfc7abf397b47 to your computer and use it in GitHub Desktop.
Save jeromeku/2a2373b97f4b505e535dfc7abf397b47 to your computer and use it in GitHub Desktop.
CuTe DSL passing vector as kernel argument
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@cute.kernel
def kernel_use_vec_as_arg(vec, res: cute.Tensor):
    """Device kernel that writes the vector argument ``vec`` into ``res``.

    Args:
        vec: The value to store. The caller visible in this file passes the
            result of ``.load()`` on a 10-element Float32 fragment.
            NOTE(review): intentionally left unannotated, as in the
            original; confirm how the DSL treats annotations before adding
            a type hint here.
        res: Destination tensor; it receives ``vec`` through ``res.store``.
            Presumably it must match ``vec`` in element count and dtype --
            TODO confirm against the CuTe DSL documentation.
    """
    # Debug aid left disabled by the original author; uncomment to print
    # ``vec`` from inside the kernel.
    # cute.print_tensor(vec)
    res.store(vec)
@cute.jit
def vector_as_kernel_arg(res: cute.Tensor):
    """Build a constant vector and pass it to a kernel as a plain argument.

    Creates a 10-element Float32 fragment, fills it with 1.0, and launches
    ``kernel_use_vec_as_arg`` with the fragment's loaded value and ``res``,
    using ``grid=[1, 1, 1]`` and ``block=[1, 1, 1]``.

    Args:
        res: Output tensor handed through to the kernel, which stores the
            vector into it. Presumably it should hold 10 Float32 elements
            to match the fragment -- TODO confirm.
    """
    # Per the original author's note, this array/vector is created on the
    # CPU (host) side of the JIT function.
    # NOTE(review): the memory space is not visible from this file -- confirm.
    vec = cute.make_fragment(10, dtype=cutlass.Float32)
    vec.fill(1.0)

    # Pass the array/vector to the kernel as an argument, without an explicit
    # host-to-device copy (the point of this example, per the author).
    kernel_use_vec_as_arg(vec.load(), res).launch(
        grid=[1, 1, 1],
        block=[1, 1, 1],
    )
import torch
# Host-side driver: allocate a zeroed 10-element float32 output on the GPU,
# run the JIT function on a DLPack view of it, then display the result.
res = torch.zeros((10,), device="cuda", dtype=torch.float32)

# Wrap the torch tensor so the CuTe DSL entry point can consume it.
res_view = from_dlpack(res)
vector_as_kernel_arg(res_view)

# Wait for outstanding GPU work to finish before reading ``res`` on the host.
torch.cuda.synchronize()
print(res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment