import cutlass
import cutlass.cute as cute
from cutlass._mlir.dialects import llvm
from cutlass._mlir.extras import types as T


def compare_and_swap_i32(
    a: cutlass.Int32, b: cutlass.Int32
) -> tuple[cutlass.Int32, cutlass.Int32]:
    """Order two signed 32-bit integers, larger first, via inline PTX.

    Args:
        a: First operand (converted with ``cutlass.Int32``).
        b: Second operand (converted with ``cutlass.Int32``).

    Returns:
        ``(max(a, b), min(a, b))`` as a pair of ``cutlass.Int32`` values.
        When ``a == b`` both elements equal that value.
    """
    # The asm returns an LLVM literal struct {i32, i32}; "=r,=r,r,r" maps
    # $0/$1 to the two 32-bit register outputs and $2/$3 to the inputs a/b.
    out_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [cutlass.Int32(a).ir_value(), cutlass.Int32(b).ir_value()],
        "{\n\t"
        ".reg .pred p;\n\t"
        # p = (a >= b), signed comparison.
        "setp.ge.s32 p, $2, $3;\n\t"
        # $0 = p ? a : b  -> max(a, b)
        "selp.s32 $0, $2, $3, p;\n\t"
        # $1 = p ? b : a  -> min(a, b). PTX `selp` takes a plain predicate
        # operand (no `!p` negation), so the data operands are swapped
        # instead of negating the predicate.
        "selp.s32 $1, $3, $2, p;\n\t"
        "}\n",
        "=r,=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    # Unpack the two struct fields back into DSL integer values.
    res0 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [0]))
    res1 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [1]))
    return res0, res1


@cute.kernel
def test(a: cutlass.Int32, b: cutlass.Int32):
    """Device kernel: order ``a``/``b`` with the PTX helper and print both."""
    c, d = compare_and_swap_i32(a, b)
    cute.printf(c, d)


@cute.jit
def host_test(a: cutlass.Int32, b: cutlass.Int32):
    """Host-side JIT entry point: launch ``test`` with a single thread."""
    test(a, b).launch(grid=(1, 1, 1), block=(1, 1, 1))


cutlass.cuda.initialize_cuda_context()
compiled_test = cute.compile(host_test, 1, 2)
compiled_test(1, 2)  # expected device output: 2, 1 (max first, then min)