@jeromeku
Forked from fengxie/vector_as_kernel_arg.py
Created September 14, 2025 17:12

Revisions

  1. @fengxie fengxie revised this gist Sep 10, 2025. 1 changed file with 3 additions and 0 deletions.
    3 changes: 3 additions & 0 deletions vector_as_kernel_arg.py
    @@ -11,8 +11,11 @@ def kernel_use_vec_as_arg(vec, res: cute.Tensor):

    @cute.jit
    def vector_as_kernel_arg(res: cute.Tensor):
        # Create an array/vector on CPU
        vec = cute.make_fragment(10, dtype=cutlass.Float32)
        vec.fill(1.0)

        # Pass array/vector to kernel as argument without explicit copy from host to device
        kernel_use_vec_as_arg(vec.load(), res).launch(grid=[1, 1, 1], block=[1, 1, 1])


  2. @fengxie fengxie revised this gist Sep 10, 2025. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion vector_as_kernel_arg.py
    @@ -1,4 +1,3 @@
    import cutlass

    import cutlass
    import cutlass.cute as cute
  3. @fengxie fengxie renamed this gist Sep 10, 2025. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  4. @fengxie fengxie created this gist Sep 10, 2025.
    25 changes: 25 additions & 0 deletions gistfile1.txt
    @@ -0,0 +1,25 @@
    import cutlass

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack


    @cute.kernel
    def kernel_use_vec_as_arg(vec, res: cute.Tensor):
        # cute.print_tensor(vec)
        res.store(vec)

    @cute.jit
    def vector_as_kernel_arg(res: cute.Tensor):
        vec = cute.make_fragment(10, dtype=cutlass.Float32)
        vec.fill(1.0)
        kernel_use_vec_as_arg(vec.load(), res).launch(grid=[1, 1, 1], block=[1, 1, 1])


    import torch

    res = torch.zeros(10, dtype=torch.float32, device="cuda")
    vector_as_kernel_arg(from_dlpack(res))
    torch.cuda.synchronize()
    print(res)