Created
December 29, 2023 18:20
-
-
Save robertknight/d95b9a6c6ac79ef8bf64cea9d534b177 to your computer and use it in GitHub Desktop.
Revisions
-
robertknight created this gist
Dec 29, 2023 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters Original file line number Diff line number Diff line change @@ -0,0 +1,214 @@
# AVX-512 GEMM micro-kernel (compiler output, Intel syntax, Mach-O / SysV ABI).
# Computes tile = alpha * (a panel x b panel) + beta * tile for a 6x32 f32 tile
# (MR == 6 rows, NR == 2 zmm registers == 32 columns per row).
#
# At entry params are:
#
# tile_ptr (rdi)                  # f32* output tile
# tile_row_stride (rsi)           # elements (not bytes) between tile rows
# a (rdx, rcx)                    # packed A panel: pointer + length
# b (r8, r9)                      # packed B panel: pointer + length
# depth (stack)                   # K dimension (loop trip count)
# alpha (xmm0)
# beta (xmm1)
.section __TEXT,__text,regular,pure_instructions
.p2align 4, 0x90
wasnn::gemm::kernels::x64::Avx512Kernel::kernel_avx_512:
Lfunc_begin758:
.cfi_startproc
        push rbp
        .cfi_def_cfa_offset 16
        .cfi_offset rbp, -16
        mov rbp, rsp
        .cfi_def_cfa_register rbp

        # --- Bounds checks on the packed input slices (panic on failure) ---
        mov rax, qword ptr [rbp + 16]           # Load `depth` into rax
        lea r10, [rax + rax]
        lea r10, [r10 + 2*r10]                  # Set r10 = `depth * MR` (where MR == 6)
        cmp r10, rcx                            # Compare `a.len()` with `depth * MR`
        ja LBB758_15                            # a slice too short -> panic
        mov rcx, rax                            # Set rcx = `depth * NR` (where NR == 2)
        shl rcx, 5                              # * 32 bytes (2 zmm loads per iteration? NOTE(review): scale looks like depth*32 — confirm against b.len() units)
        cmp rcx, r9
        ja LBB758_16                            # b slice too short -> panic

        test rax, rax                           # Check if there are any loop iterations
        je LBB758_3                             # depth == 0: skip loop, tmp stays zero
        add rdx, 20                             # bias a ptr so loop uses [rdx-20..rdx] for the 6 row elements
        add r8, 64                              # bias b ptr so loop uses [r8-64] and [r8]

        # Clear registers that hold `tmp`. The registers that hold `b_rows` don't
        # need to be cleared because they are dead stores.
        vxorps xmm2, xmm2, xmm2
        vxorps xmm3, xmm3, xmm3
        vxorps xmm4, xmm4, xmm4
        vxorps xmm5, xmm5, xmm5
        vxorps xmm6, xmm6, xmm6
        vxorps xmm7, xmm7, xmm7
        vxorps xmm8, xmm8, xmm8
        vxorps xmm9, xmm9, xmm9
        vxorps xmm10, xmm10, xmm10
        vxorps xmm11, xmm11, xmm11
        vxorps xmm12, xmm12, xmm12
        vxorps xmm13, xmm13, xmm13
        .p2align 4, 0x90

# --- Main depth loop: one K step per iteration, 12 accumulators live ---
LBB758_8:
        # Load `b_rows[i]` (two zmm = 32 f32 of the current B row)
        vmovups zmm14, zmmword ptr [r8 - 64]
        vmovups zmm15, zmmword ptr [r8]
        # tmp[i][j] = fmadd(broadcast(a[i]), b_rows[j]) — unrolled over MR == 6 rows
        vbroadcastss zmm16, dword ptr [rdx - 20]
        vfmadd231ps zmm13, zmm16, zmm14
        vfmadd231ps zmm12, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 16]
        vfmadd231ps zmm11, zmm16, zmm14
        vfmadd231ps zmm10, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 12]
        vfmadd231ps zmm9, zmm16, zmm14
        vfmadd231ps zmm8, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 8]
        vfmadd231ps zmm7, zmm16, zmm14
        vfmadd231ps zmm6, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 4]
        vfmadd231ps zmm5, zmm16, zmm14
        vfmadd231ps zmm4, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx]
        vfmadd231ps zmm3, zmm16, zmm14
        vfmadd231ps zmm2, zmm16, zmm15
        add rdx, 24                             # advance a by MR * 4 bytes
        sub r8, -128                            # advance b by 128 bytes (shorter encoding than add 128)
        dec rax
        jne LBB758_8                            # Jump to top of depth loop if not final iteration

        # Test if `alpha == 1` (jne/jnp pair: unordered (NaN) falls through to the general path)
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne LBB758_9
        jnp LBB758_5
        jmp LBB758_9

# depth == 0: accumulators were never initialized by the loop; zero them here.
LBB758_3:
        vxorps xmm2, xmm2, xmm2
        vxorps xmm3, xmm3, xmm3
        vxorps xmm4, xmm4, xmm4
        vxorps xmm5, xmm5, xmm5
        vxorps xmm6, xmm6, xmm6
        vxorps xmm7, xmm7, xmm7
        vxorps xmm8, xmm8, xmm8
        vxorps xmm9, xmm9, xmm9
        vxorps xmm10, xmm10, xmm10
        vxorps xmm11, xmm11, xmm11
        vxorps xmm12, xmm12, xmm12
        vxorps xmm13, xmm13, xmm13
        # Test if `alpha == 1`
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne LBB758_9
        jp LBB758_9

# --- Fast path 1: alpha == 1 — now test if `beta == 0` (tile = tmp, plain store) ---
LBB758_5:
        vxorps xmm14, xmm14, xmm14
        vucomiss xmm1, xmm14
        jne LBB758_9
        jp LBB758_9
        # Store `tmp[i][j]` to `tile_ptr`; row stride is 4*rsi bytes (f32 elements)
        vmovups zmmword ptr [rdi], zmm13
        vmovups zmmword ptr [rdi + 64], zmm12
        vmovups zmmword ptr [rdi + 4*rsi], zmm11
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
        vmovups zmmword ptr [rdi + 8*rsi], zmm9
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
        lea rax, [rsi + 2*rsi]                  # rax = 3 * stride (row 3)
        vmovups zmmword ptr [rdi + 4*rax], zmm7
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
        lea rax, [rsi + 4*rsi]                  # rax = 5 * stride (row 5)
        shl rsi, 4                              # rsi = 16 * stride = 4 bytes * 4 * stride (row 4)
        vmovups zmmword ptr [rdi + rsi], zmm5
        vmovups zmmword ptr [rdi + rsi + 64], zmm4
        vmovups zmmword ptr [rdi + 4*rax], zmm3
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
        pop rbp
        # See https://community.intel.com/t5/Intel-ISA-Extensions/What-is-the-status-of-VZEROUPPER-use/m-p/1098375
        vzeroupper
        ret

# --- Fast path 2: check if `beta == 1 && alpha == 1` (tile += tmp) ---
LBB758_9:
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne LBB758_12
        jp LBB758_12
        vucomiss xmm1, dword ptr [rip + LCPI758_0]
        jne LBB758_12
        jp LBB758_12
        # tile[i][j] += tmp[i][j] (read-modify-write each half-row)
        vaddps zmm0, zmm13, zmmword ptr [rdi]
        vmovups zmmword ptr [rdi], zmm0
        vaddps zmm0, zmm12, zmmword ptr [rdi + 64]
        vmovups zmmword ptr [rdi + 64], zmm0
        vaddps zmm0, zmm11, zmmword ptr [rdi + 4*rsi]
        vmovups zmmword ptr [rdi + 4*rsi], zmm0
        vaddps zmm0, zmm10, zmmword ptr [rdi + 4*rsi + 64]
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm0
        vaddps zmm0, zmm9, zmmword ptr [rdi + 8*rsi]
        vmovups zmmword ptr [rdi + 8*rsi], zmm0
        vaddps zmm0, zmm8, zmmword ptr [rdi + 8*rsi + 64]
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm0
        lea rax, [rsi + 2*rsi]                  # row 3
        vaddps zmm0, zmm7, zmmword ptr [rdi + 4*rax]
        vmovups zmmword ptr [rdi + 4*rax], zmm0
        vaddps zmm0, zmm6, zmmword ptr [rdi + 4*rax + 64]
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
        lea rax, [rsi + 4*rsi]                  # row 5
        shl rsi, 4                              # row 4 byte offset
        vaddps zmm0, zmm5, zmmword ptr [rdi + rsi]
        vmovups zmmword ptr [rdi + rsi], zmm0
        vaddps zmm0, zmm4, zmmword ptr [rdi + rsi + 64]
        vmovups zmmword ptr [rdi + rsi + 64], zmm0
        vaddps zmm0, zmm3, zmmword ptr [rdi + 4*rax]
        vmovups zmmword ptr [rdi + 4*rax], zmm0
        vaddps zmm0, zmm2, zmmword ptr [rdi + 4*rax + 64]
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
        pop rbp
        vzeroupper
        ret

# --- General path: tile = alpha * tmp + beta * tile ---
LBB758_12:
        vbroadcastss zmm0, xmm0                 # splat alpha
        vbroadcastss zmm1, xmm1                 # splat beta
        # Per half-row: t = beta * tile; tile = alpha * tmp + t
        vmulps zmm14, zmm1, zmmword ptr [rdi]
        vfmadd213ps zmm13, zmm0, zmm14
        vmovups zmmword ptr [rdi], zmm13
        vmulps zmm13, zmm1, zmmword ptr [rdi + 64]
        vfmadd213ps zmm12, zmm0, zmm13
        vmovups zmmword ptr [rdi + 64], zmm12
        vmulps zmm12, zmm1, zmmword ptr [rdi + 4*rsi]
        vfmadd213ps zmm11, zmm0, zmm12
        vmovups zmmword ptr [rdi + 4*rsi], zmm11
        vmulps zmm11, zmm1, zmmword ptr [rdi + 4*rsi + 64]
        vfmadd213ps zmm10, zmm0, zmm11
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
        vmulps zmm10, zmm1, zmmword ptr [rdi + 8*rsi]
        vfmadd213ps zmm9, zmm0, zmm10
        vmovups zmmword ptr [rdi + 8*rsi], zmm9
        vmulps zmm9, zmm1, zmmword ptr [rdi + 8*rsi + 64]
        vfmadd213ps zmm8, zmm0, zmm9
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
        lea rax, [rsi + 2*rsi]                  # row 3
        vmulps zmm8, zmm1, zmmword ptr [rdi + 4*rax]
        vfmadd213ps zmm7, zmm0, zmm8
        vmovups zmmword ptr [rdi + 4*rax], zmm7
        vmulps zmm7, zmm1, zmmword ptr [rdi + 4*rax + 64]
        vfmadd213ps zmm6, zmm0, zmm7
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
        lea rax, [rsi + 4*rsi]                  # row 5
        shl rsi, 4                              # row 4 byte offset
        vmulps zmm6, zmm1, zmmword ptr [rdi + rsi]
        vfmadd213ps zmm5, zmm0, zmm6
        vmovups zmmword ptr [rdi + rsi], zmm5
        vmulps zmm5, zmm1, zmmword ptr [rdi + rsi + 64]
        vfmadd213ps zmm4, zmm0, zmm5
        vmovups zmmword ptr [rdi + rsi + 64], zmm4
        vmulps zmm4, zmm1, zmmword ptr [rdi + 4*rax]
        vfmadd213ps zmm3, zmm0, zmm4
        vmovups zmmword ptr [rdi + 4*rax], zmm3
        vmulps zmm1, zmm1, zmmword ptr [rdi + 4*rax + 64]
        vfmadd213ps zmm2, zmm0, zmm1
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
        pop rbp
        vzeroupper
        ret

# --- Panic stubs for the slice-length checks (never return) ---
LBB758_15:
        lea rdi, [rip + l___unnamed_545]        # panic message
        lea rdx, [rip + l___unnamed_551]        # source location
        mov esi, 39                             # message length
        call core::panicking::panic
LBB758_16:
        lea rdi, [rip + l___unnamed_547]        # panic message
        lea rdx, [rip + l___unnamed_552]        # source location
        mov esi, 39                             # message length
        call core::panicking::panic