Skip to content

Instantly share code, notes, and snippets.

@robertknight
Created December 29, 2023 18:20
Show Gist options
  • Save robertknight/d95b9a6c6ac79ef8bf64cea9d534b177 to your computer and use it in GitHub Desktop.

Revisions

  1. robertknight created this gist Dec 29, 2023.
    214 changes: 214 additions & 0 deletions avx-512-gemm-kernel.s
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,214 @@
    # AVX-512 f32 GEMM micro-kernel (compiler-generated, Intel syntax,
    # Mach-O/macOS). Updates a 6 x 32 output tile:
    #
    #   tile = alpha * (a x b) + beta * tile
    #
    # MR == 6 rows, NR == 2 zmm registers per row (2 x 16 f32 lanes = 32
    # columns). The 12 accumulators live in zmm2..zmm13. `a` and `b` are
    # passed as (pointer, length) slice pairs, so rcx = a.len() and
    # r9 = b.len(). LCPI758_0 is a .rodata constant; the "alpha == 1" /
    # "beta == 1" comparisons below show it holds 1.0f32.
    #
    # At entry params are:
    #
    # tile_ptr (rdi)
    # tile_row_stride (rsi)
    # a (rdx, rcx)
    # b (r8, r9)
    # depth (stack)
    # alpha (xmm0)
    # beta (xmm1)
    .section __TEXT,__text,regular,pure_instructions
    .p2align 4, 0x90
    wasnn::gemm::kernels::x64::Avx512Kernel::kernel_avx_512:
    Lfunc_begin758:
    .cfi_startproc
    push rbp
    .cfi_def_cfa_offset 16
    .cfi_offset rbp, -16
    mov rbp, rsp
    .cfi_def_cfa_register rbp
    mov rax, qword ptr [rbp + 16] # Load `depth` into rax
    # Bounds check `a`: it must hold at least `depth * MR` f32 elements.
    lea r10, [rax + rax]
    lea r10, [r10 + 2*r10] # Set r10 = `depth * MR` (where MR == 6)
    cmp r10, rcx # Compare `a.len()` with `depth * MR`
    ja LBB758_15 # Panic: `a` too short
    # Bounds check `b`: it must hold at least `depth * NR * 16` f32
    # elements (shl by 5 == multiply by 32 == NR zmm rows of 16 lanes each).
    mov rcx, rax # Set rcx = `depth * NR` (where NR == 2)
    shl rcx, 5
    cmp rcx, r9
    ja LBB758_16 # Panic: `b` too short
    test rax, rax # Check if there are any loop iterations
    je LBB758_3 # depth == 0: skip the loop, just zero the accumulators
    # Bias the pointers so the loop body can reach all six `a` values with
    # displacements -20..0 and both `b` rows with displacements -64 and 0.
    add rdx, 20
    add r8, 64
    # Clear registers that hold `tmp`. The registers that hold `b_rows` don't
    # need to be cleared because they are dead stores.
    vxorps xmm2, xmm2, xmm2
    vxorps xmm3, xmm3, xmm3
    vxorps xmm4, xmm4, xmm4
    vxorps xmm5, xmm5, xmm5
    vxorps xmm6, xmm6, xmm6
    vxorps xmm7, xmm7, xmm7
    vxorps xmm8, xmm8, xmm8
    vxorps xmm9, xmm9, xmm9
    vxorps xmm10, xmm10, xmm10
    vxorps xmm11, xmm11, xmm11
    vxorps xmm12, xmm12, xmm12
    vxorps xmm13, xmm13, xmm13
    .p2align 4, 0x90
    # Depth loop: each iteration consumes 6 f32 from `a` (one per tile row)
    # and one 32-wide f32 row from `b`. Tile row i accumulates into the
    # register pair (zmm13, zmm12) for row 0 down to (zmm3, zmm2) for row 5.
    LBB758_8:
    # Load `b_rows[i]`
    vmovups zmm14, zmmword ptr [r8 - 64]
    vmovups zmm15, zmmword ptr [r8]
    # tmp[i][j] = fmadd(broadcast(a[i]), b_rows[j])
    vbroadcastss zmm16, dword ptr [rdx - 20]
    vfmadd231ps zmm13, zmm16, zmm14
    vfmadd231ps zmm12, zmm15, zmm16
    vbroadcastss zmm16, dword ptr [rdx - 16]
    vfmadd231ps zmm11, zmm16, zmm14
    vfmadd231ps zmm10, zmm15, zmm16
    vbroadcastss zmm16, dword ptr [rdx - 12]
    vfmadd231ps zmm9, zmm16, zmm14
    vfmadd231ps zmm8, zmm15, zmm16
    vbroadcastss zmm16, dword ptr [rdx - 8]
    vfmadd231ps zmm7, zmm16, zmm14
    vfmadd231ps zmm6, zmm15, zmm16
    vbroadcastss zmm16, dword ptr [rdx - 4]
    vfmadd231ps zmm5, zmm16, zmm14
    vfmadd231ps zmm4, zmm15, zmm16
    vbroadcastss zmm16, dword ptr [rdx]
    vfmadd231ps zmm3, zmm16, zmm14
    vfmadd231ps zmm2, zmm16, zmm15 # Multiplicand order swapped vs. the pairs above; the FMA multiply is commutative
    add rdx, 24 # Advance `a` by MR (6) f32 == 24 bytes
    sub r8, -128 # Advance `b` by 32 f32 == 128 bytes (-128 fits in an 8-bit immediate, so this encodes shorter than `add r8, 128`)
    dec rax
    jne LBB758_8 # Jump to top of depth loop if not final iteration
    # Dispatch on (alpha, beta). vucomiss sets ZF on equality and PF when
    # either operand is NaN, hence the paired jne/jnp (or jne/jp) tests.
    vucomiss xmm0, dword ptr [rip + LCPI758_0] # alpha == 1?
    jne LBB758_9
    jnp LBB758_5 # alpha == 1 (and not NaN): go test beta
    jmp LBB758_9 # NaN: treat as alpha != 1
    LBB758_3:
    # depth == 0: the product is empty. Zero the accumulators, then fall
    # into the same alpha/beta dispatch as the loop exit above.
    vxorps xmm2, xmm2, xmm2
    vxorps xmm3, xmm3, xmm3
    vxorps xmm4, xmm4, xmm4
    vxorps xmm5, xmm5, xmm5
    vxorps xmm6, xmm6, xmm6
    vxorps xmm7, xmm7, xmm7
    vxorps xmm8, xmm8, xmm8
    vxorps xmm9, xmm9, xmm9
    vxorps xmm10, xmm10, xmm10
    vxorps xmm11, xmm11, xmm11
    vxorps xmm12, xmm12, xmm12
    vxorps xmm13, xmm13, xmm13
    # Test if `alpha == 1`
    vucomiss xmm0, dword ptr [rip + LCPI758_0]
    jne LBB758_9
    jp LBB758_9 # NaN: treat as alpha != 1
    LBB758_5:
    # Test if `beta == 0`
    vxorps xmm14, xmm14, xmm14
    vucomiss xmm1, xmm14
    jne LBB758_9
    jp LBB758_9
    # Fast path: alpha == 1 && beta == 0, i.e. tile = tmp. The existing
    # tile contents are never read.
    # Store `tmp[i][j]` to `tile_ptr`
    # `tile_row_stride` (rsi) counts f32 elements; the byte stride is 4*rsi.
    vmovups zmmword ptr [rdi], zmm13
    vmovups zmmword ptr [rdi + 64], zmm12
    vmovups zmmword ptr [rdi + 4*rsi], zmm11
    vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
    vmovups zmmword ptr [rdi + 8*rsi], zmm9
    vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
    lea rax, [rsi + 2*rsi] # rax = 3 * stride (elements), for row 3
    vmovups zmmword ptr [rdi + 4*rax], zmm7
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
    lea rax, [rsi + 4*rsi] # rax = 5 * stride (row 5); computed before rsi is clobbered below
    shl rsi, 4 # rsi = 16 * stride = byte offset of row 4 (4 rows * 4 bytes/f32)
    vmovups zmmword ptr [rdi + rsi], zmm5
    vmovups zmmword ptr [rdi + rsi + 64], zmm4
    vmovups zmmword ptr [rdi + 4*rax], zmm3
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
    pop rbp
    # See https://community.intel.com/t5/Intel-ISA-Extensions/What-is-the-status-of-VZEROUPPER-use/m-p/1098375
    vzeroupper
    ret
    LBB758_9:
    # Check if `beta == 1 && alpha == 1`
    vucomiss xmm0, dword ptr [rip + LCPI758_0]
    jne LBB758_12
    jp LBB758_12
    vucomiss xmm1, dword ptr [rip + LCPI758_0]
    jne LBB758_12
    jp LBB758_12
    # Accumulate path: alpha == 1 && beta == 1, i.e. tile += tmp.
    # Same row-addressing scheme as the beta == 0 store path above.
    vaddps zmm0, zmm13, zmmword ptr [rdi]
    vmovups zmmword ptr [rdi], zmm0
    vaddps zmm0, zmm12, zmmword ptr [rdi + 64]
    vmovups zmmword ptr [rdi + 64], zmm0
    vaddps zmm0, zmm11, zmmword ptr [rdi + 4*rsi]
    vmovups zmmword ptr [rdi + 4*rsi], zmm0
    vaddps zmm0, zmm10, zmmword ptr [rdi + 4*rsi + 64]
    vmovups zmmword ptr [rdi + 4*rsi + 64], zmm0
    vaddps zmm0, zmm9, zmmword ptr [rdi + 8*rsi]
    vmovups zmmword ptr [rdi + 8*rsi], zmm0
    vaddps zmm0, zmm8, zmmword ptr [rdi + 8*rsi + 64]
    vmovups zmmword ptr [rdi + 8*rsi + 64], zmm0
    lea rax, [rsi + 2*rsi] # rax = 3 * stride (row 3)
    vaddps zmm0, zmm7, zmmword ptr [rdi + 4*rax]
    vmovups zmmword ptr [rdi + 4*rax], zmm0
    vaddps zmm0, zmm6, zmmword ptr [rdi + 4*rax + 64]
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
    lea rax, [rsi + 4*rsi] # rax = 5 * stride (row 5)
    shl rsi, 4 # rsi = byte offset of row 4
    vaddps zmm0, zmm5, zmmword ptr [rdi + rsi]
    vmovups zmmword ptr [rdi + rsi], zmm0
    vaddps zmm0, zmm4, zmmword ptr [rdi + rsi + 64]
    vmovups zmmword ptr [rdi + rsi + 64], zmm0
    vaddps zmm0, zmm3, zmmword ptr [rdi + 4*rax]
    vmovups zmmword ptr [rdi + 4*rax], zmm0
    vaddps zmm0, zmm2, zmmword ptr [rdi + 4*rax + 64]
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
    pop rbp
    vzeroupper
    ret
    LBB758_12:
    # General path: tile = alpha * tmp + beta * tile. Broadcast the scalar
    # alpha/beta to all 16 lanes, then per output vector:
    #   scratch = beta * tile        (vmulps)
    #   acc     = alpha * acc + scratch  (vfmadd213ps)
    #   store acc
    # Each just-stored accumulator register is recycled as the scratch
    # register for the next vector.
    vbroadcastss zmm0, xmm0
    vbroadcastss zmm1, xmm1
    vmulps zmm14, zmm1, zmmword ptr [rdi]
    vfmadd213ps zmm13, zmm0, zmm14
    vmovups zmmword ptr [rdi], zmm13
    vmulps zmm13, zmm1, zmmword ptr [rdi + 64]
    vfmadd213ps zmm12, zmm0, zmm13
    vmovups zmmword ptr [rdi + 64], zmm12
    vmulps zmm12, zmm1, zmmword ptr [rdi + 4*rsi]
    vfmadd213ps zmm11, zmm0, zmm12
    vmovups zmmword ptr [rdi + 4*rsi], zmm11
    vmulps zmm11, zmm1, zmmword ptr [rdi + 4*rsi + 64]
    vfmadd213ps zmm10, zmm0, zmm11
    vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
    vmulps zmm10, zmm1, zmmword ptr [rdi + 8*rsi]
    vfmadd213ps zmm9, zmm0, zmm10
    vmovups zmmword ptr [rdi + 8*rsi], zmm9
    vmulps zmm9, zmm1, zmmword ptr [rdi + 8*rsi + 64]
    vfmadd213ps zmm8, zmm0, zmm9
    vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
    lea rax, [rsi + 2*rsi] # rax = 3 * stride (row 3)
    vmulps zmm8, zmm1, zmmword ptr [rdi + 4*rax]
    vfmadd213ps zmm7, zmm0, zmm8
    vmovups zmmword ptr [rdi + 4*rax], zmm7
    vmulps zmm7, zmm1, zmmword ptr [rdi + 4*rax + 64]
    vfmadd213ps zmm6, zmm0, zmm7
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
    lea rax, [rsi + 4*rsi] # rax = 5 * stride (row 5)
    shl rsi, 4 # rsi = byte offset of row 4
    vmulps zmm6, zmm1, zmmword ptr [rdi + rsi]
    vfmadd213ps zmm5, zmm0, zmm6
    vmovups zmmword ptr [rdi + rsi], zmm5
    vmulps zmm5, zmm1, zmmword ptr [rdi + rsi + 64]
    vfmadd213ps zmm4, zmm0, zmm5
    vmovups zmmword ptr [rdi + rsi + 64], zmm4
    vmulps zmm4, zmm1, zmmword ptr [rdi + 4*rax]
    vfmadd213ps zmm3, zmm0, zmm4
    vmovups zmmword ptr [rdi + 4*rax], zmm3
    vmulps zmm1, zmm1, zmmword ptr [rdi + 4*rax + 64]
    vfmadd213ps zmm2, zmm0, zmm1
    vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
    pop rbp
    vzeroupper
    ret
    LBB758_15:
    # `a` bounds check failed: call the Rust panic handler (does not return).
    lea rdi, [rip + l___unnamed_545] # panic message data
    lea rdx, [rip + l___unnamed_551] # NOTE(review): presumably the panic Location record — confirm against the rodata
    mov esi, 39 # NOTE(review): presumably the message length in bytes — verify against l___unnamed_545
    call core::panicking::panic
    LBB758_16:
    # `b` bounds check failed: call the Rust panic handler (does not return).
    lea rdi, [rip + l___unnamed_547] # panic message data
    lea rdx, [rip + l___unnamed_552] # NOTE(review): presumably the panic Location record — confirm against the rodata
    mov esi, 39 # NOTE(review): presumably the message length in bytes — verify against l___unnamed_547
    call core::panicking::panic