Skip to content

Instantly share code, notes, and snippets.

@jeromeku
Forked from Observer007/test.py
Created August 27, 2025 01:25
Show Gist options
  • Save jeromeku/322c6fa1dba15ced54253f512ab721b9 to your computer and use it in GitHub Desktop.
Save jeromeku/322c6fa1dba15ced54253f512ab721b9 to your computer and use it in GitHub Desktop.
CuTe DSL: `inline_asm` returning more than one value
import cutlass
import cutlass.cute as cute
from cutlass._mlir.dialects import llvm
from cutlass._mlir.extras import types as T
def compare_and_swap_i32(a: cutlass.Int32, b: cutlass.Int32) -> tuple[cutlass.Int32, cutlass.Int32]:
    """Order two signed 32-bit integers with a single inline-PTX snippet.

    Despite the name, this is not an atomic compare-and-swap. It is a
    "sort two values" primitive that returns ``(max(a, b), min(a, b))``.

    It also shows how to get more than one result out of
    ``llvm.inline_asm``:

    1. Declare an LLVM literal struct as the result type.
    2. Bind one ``=r`` output constraint per struct field.
    3. Pull the fields out with ``llvm.extractvalue``.

    Args:
        a: First operand (asm operand ``$2``).
        b: Second operand (asm operand ``$3``).

    Returns:
        ``(larger, smaller)`` as ``cutlass.Int32`` values. When ``a == b``,
        both elements equal that value.
    """
    # Result type {i32, i32}: struct field i is written by asm output $i.
    result_ty = llvm.StructType.get_literal([T.i32(), T.i32()])
    operands = [cutlass.Int32(a).ir_value(), cutlass.Int32(b).ir_value()]
    out_i32x2 = llvm.inline_asm(
        result_ty,
        operands,
        # The braces open a PTX scope, so the local `.reg .pred p` does not
        # clash with another copy of this snippet inlined into the same kernel.
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ge.s32 p, $2, $3;\n\t"
        # $0 = p ? a : b  ->  max(a, b)
        "selp.s32 $0, $2, $3, p;\n\t"
        # $1 = p ? b : a  ->  min(a, b). The data operands are swapped rather
        # than negating the selector (`!p`), because the documented `selp`
        # syntax takes a plain predicate operand.
        "selp.s32 $1, $3, $2, p;\n\t"
        "}\n",
        # Outputs are numbered first ($0, $1), then inputs ($2, $3).
        "=r,=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    res0 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [0]))
    res1 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [1]))
    return res0, res1
@cute.kernel
def test(a: cutlass.Int32, b: cutlass.Int32):
    """Device kernel: order (a, b) with compare_and_swap_i32 and print both.

    The helper returns the larger value first, so that value is printed first.
    """
    larger, smaller = compare_and_swap_i32(a, b)
    cute.printf(larger, smaller)
@cute.jit
def host_test(a: cutlass.Int32, b: cutlass.Int32):
    """Host-side JIT entry point that launches the `test` kernel.

    Grid and block are both (1, 1, 1), so the kernel runs on exactly one
    thread and prints once.
    """
    single = (1, 1, 1)
    launcher = test(a, b)
    launcher.launch(grid=single, block=single)
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse
    # compare_and_swap_i32) does not initialize CUDA, compile, or launch
    # anything. Running the file as a script behaves exactly as before.
    cutlass.cuda.initialize_cuda_context()
    # cute.compile traces host_test with these sample arguments and returns a
    # callable for the compiled program.
    compiled_test = cute.compile(host_test, 1, 2)
    # Expected device output: 2 then 1 (max first, min second).
    compiled_test(1, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment