Skip to content

Instantly share code, notes, and snippets.

@jeromeku
Forked from fengxie/pass_list_as_vector.py
Created September 14, 2025 17:12
Show Gist options
  • Save jeromeku/9a19019b119a3a7dfc5d7ed15a58f985 to your computer and use it in GitHub Desktop.
Save jeromeku/9a19019b119a3a7dfc5d7ed15a58f985 to your computer and use it in GitHub Desktop.

Revisions

  1. @fengxie fengxie created this gist Sep 10, 2025.
    29 changes: 29 additions & 0 deletions pass_list_as_vector.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,29 @@

    from typing import List

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack


    @cute.kernel
    def kernel_use_vec_as_arg(vec, res: cute.Tensor):
    res.store(vec)
    for i in cutlass.range(10):
    cute.printf("vec[%d]: %d", i, vec[i])
    # cute.print_tensor(vec)

    @cute.jit
    def pass_list_as_vector(xs: List[cutlass.Int32], res: cute.Tensor):
    vec = cute.make_fragment(10, dtype=cutlass.Int32)
    for i, x in enumerate(xs):
    vec[i] = x
    kernel_use_vec_as_arg(vec.load(), res).launch(grid=[1, 1, 1], block=[1, 1, 1])


    import torch

    res = torch.zeros(10, dtype=torch.int32, device="cuda")
    pass_list_as_vector([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], from_dlpack(res))
    torch.cuda.synchronize()
    print(res)