-
-
Save jeromeku/322c6fa1dba15ced54253f512ab721b9 to your computer and use it in GitHub Desktop.
CuTe DSL: inline_asm returning more than one value
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import cutlass | |
| import cutlass.cute as cute | |
| from cutlass._mlir.dialects import llvm | |
| from cutlass._mlir.extras import types as T | |
def compare_and_swap_i32(a: cutlass.Int32, b: cutlass.Int32) -> tuple[cutlass.Int32, cutlass.Int32]:
    """Return the two signed 32-bit inputs ordered as ``(larger, smaller)``.

    Shows how a single ``llvm.inline_asm`` call can produce two results: the
    asm's result type is a literal struct of two ``i32`` fields, and each
    field is then pulled out with ``llvm.extractvalue``.

    Args:
        a: First operand; converted with ``cutlass.Int32(a)`` before lowering.
        b: Second operand; converted with ``cutlass.Int32(b)`` before lowering.

    Returns:
        ``(res0, res1)``: ``res0`` is ``a`` when ``a >= b`` and ``b``
        otherwise; ``res1`` is the remaining operand.
    """
    # Operand numbering follows the constraint string "=r,=r,r,r":
    #   $0, $1 -> the two struct fields (outputs), $2 -> a, $3 -> b.
    #
    # The outputs are not declared early-clobber, so the compiler may place an
    # output in the same register as an input. To stay correct in that case,
    # the first result is built in a scoped temporary and $0 is written only
    # after $2/$3 have been read for the last time.
    out_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [cutlass.Int32(a).ir_value(), cutlass.Int32(b).ir_value()],
        "{\n\t"
        ".reg .pred p;\n\t"
        ".reg .s32 tmp;\n\t"
        "setp.ge.s32 p, $2, $3;\n\t"  # p = (a >= b)
        "selp.s32 tmp, $2, $3, p;\n\t"  # tmp = p ? a : b
        "selp.s32 $1, $2, $3, !p;\n\t"  # $1 = !p ? a : b (last read of inputs)
        "mov.s32 $0, tmp;\n\t"  # publish the first result
        "}\n",
        "=r,=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    # Unpack the two struct fields back into DSL integers.
    res0 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [0]))
    res1 = cutlass.Int32(llvm.extractvalue(T.i32(), out_i32x2, [1]))
    return res0, res1
@cute.kernel
def test(a: cutlass.Int32, b: cutlass.Int32):
    """Device kernel: run ``compare_and_swap_i32`` on ``a`` and ``b`` and print both results."""
    first, second = compare_and_swap_i32(a, b)
    cute.printf(first, second)
@cute.jit
def host_test(a: cutlass.Int32, b: cutlass.Int32):
    """Host entry point: launch the ``test`` kernel with a 1x1x1 grid of 1x1x1 blocks."""
    pending_launch = test(a, b)
    pending_launch.launch(grid=(1, 1, 1), block=(1, 1, 1))
# Initialize the CUDA context first; compilation and launch follow.
cutlass.cuda.initialize_cuda_context()

# Compile the host entry point with 1 and 2 as example arguments, then run
# the compiled callable with the same values.
compiled_test = cute.compile(host_test, 1, 2)
compiled_test(1, 2)  # print(2, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment