# GGML CUDA/HIP Inference Paths and Precision by Architecture

This document summarizes how ggml's CUDA/HIP backend executes inference on different GPU families, which code paths are used, and at what numeric precision the major compute happens. It also gives rough workload-composition percentages that relate each path to an architecture's FLOPS/TOPS.

References are to files under `ggml/src/ggml-cuda` unless noted.

- Matmul (quantized): `mmq.cu`, `mmq.cuh`, `mmvq.cu`, `vecdotq.cuh`, `quantize.cu/.cuh`
- Matmul (float): `mmf.cu`, `mmvf.cu`, cuBLAS/hipBLAS calls in `ggml-cuda.cu`
- FlashAttention: `fattn*.cu/.cuh`
- Softmax: `softmax.cu`
- Norms: `norm.cu`
- RoPE: `rope.cu`
- Embedding/time-step embed: `getrows.cu`, `tsembd.cu`
- Feature detection + arch gates: `common.cuh`, `mma.cuh`, `vendors/hip.h`

## Workload Composition (typical decode, single token)

These shares are model- and context-length dependent. As a rule of thumb for standard LLMs (e.g., the LLaMA family) in quantized inference:

- GEMMs with quantized weights (Q, K, V, output projection, MLP up/gate/down): ~85–95% of total compute
  - Executed via INT8 dot-product kernels: MMVQ (`mmvq.cu`) for single-token decode and MMQ for larger batches, using DP4A, Tensor Cores, or MFMA as available. MMVQ uses DP4A-class dot products only; at decode these are matrix-vector products bound mainly by memory bandwidth.
  - INT32 partial sums are scaled and accumulated in FP32; outputs are FP32
- Attention kernels (Q·K^T, softmax, P·V): ~3–12%
  - Q·K^T and P·V tiles often run with FP16 MMAs (if supported) or FP16/FP32 vector/tile kernels
  - Softmax is FP32
- "Glue" ops (RoPE, norms, masking, embeddings, residual adds, elementwise): ~1–5% combined, memory-bound
- On-the-fly activation quantization to q8_1: typically a low single-digit %

For short contexts, the attention share is near the low end; for long contexts (large K/V window) it grows toward the high end and can exceed this range at very long contexts.
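To see why the attention share grows with context, the per-token FLOPs of the weight GEMMs and of attention can be compared directly. The sketch below is a rough estimate under stated assumptions: a LLaMA-2-7B-like shape (32 layers, `d_model` 4096, 32 heads, no GQA, FFN width 11008), 2 FLOPs per multiply-add, and no LM head, embeddings, or glue ops. It counts FLOPs only; single-token decode is largely memory-bandwidth bound, so time shares will differ. `ModelShape`, `decode_flops`, and `attention_share` are illustrative names, not ggml APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelShape:
    n_layer: int
    d_model: int    # n_head * head_dim
    n_head: int
    n_kv_head: int  # smaller than n_head when GQA is used
    d_ff: int       # MLP hidden width


def decode_flops(m: ModelShape, n_ctx: int) -> tuple[int, int]:
    """Return (weight_gemm_flops, attention_flops) for one decoded token."""
    head_dim = m.d_model // m.n_head
    d_kv = m.n_kv_head * head_dim
    # Q and output projections (d x d), K and V projections (d x d_kv),
    # MLP gate/up/down (3 x d x d_ff); 2 FLOPs per multiply-add.
    weight_params = 2 * m.d_model * m.d_model + 2 * m.d_model * d_kv + 3 * m.d_model * m.d_ff
    gemm = 2 * weight_params * m.n_layer
    # Q·K^T and P·V: each is n_head * head_dim * n_ctx multiply-adds per layer.
    attn = 2 * 2 * m.n_head * head_dim * n_ctx * m.n_layer
    return gemm, attn


def attention_share(m: ModelShape, n_ctx: int) -> float:
    gemm, attn = decode_flops(m, n_ctx)
    return attn / (gemm + attn)


LLAMA2_7B_LIKE = ModelShape(n_layer=32, d_model=4096, n_head=32, n_kv_head=32, d_ff=11008)

if __name__ == "__main__":
    for n_ctx in (512, 4096, 32768):
        share = attention_share(LLAMA2_7B_LIKE, n_ctx)
        print(f"n_ctx={n_ctx:6d}  attention share of FLOPs = {share:.1%}")
```

With these assumptions the attention share is about 2.0% at 512 tokens, 14.2% at 4096, and 57.0% at 32768: close to the low end of the ranges above for short contexts, and dominant at very long contexts.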
## Operation Precision Map (by op type)

- Weight GEMM (activation × quantized weight)
  - Path: MMQ kernels (`mmq.cu/.cuh`); very small batches use MMVQ (`mmvq.cu`)
  - Activations are quantized on the fly to `q8_1` blocks (`quantize.cu`); weights stay in their quantized format (Q4/Q5/Q8/k-quants/IQ*)
  - Compute: INT8×INT8 dot products (DP4A / Tensor Core int8 MMA / AMD MFMA int8)
  - Accumulation: INT32 partials, scaled into FP32 outputs
- Float GEMM (unquantized weights or fallback): cuBLAS/hipBLAS or custom MMF/MMVF
  - Precision: F16/BF16/F32 depending on tensor types and hardware
- FlashAttention (Q·K^T, softmax, P·V)
  - FP16 MMA when available on the platform (NVIDIA Volta and newer; AMD via rocWMMA when enabled)
  - Fallback vector/tile kernels in F16/F32
- Softmax: with FA, softmax is fused inside the `fattn*` kernels; without FA, the standalone FP32 kernel in `softmax.cu` is used
- RoPE: FP32 math (`rope.cu`)
- RMSNorm/LayerNorm: FP32 (`norm.cu`)
- Embedding/getrows/timestep: FP32 (`getrows.cu`, `tsembd.cu`)

## How the Matmul Path Is Chosen

Entry points select among:

- Quantized matrix-vector (MMVQ), for very small batches such as single-token decode: `mmvq.cu`
- Quantized GEMM (MMQ): `ggml_cuda_mul_mat_q` / `ggml_cuda_op_mul_mat_q` (`mmq.cu`)
- Float GEMM (MMF/MMVF) or cuBLAS/hipBLAS: `ggml_cuda_op_mul_mat_f/vec_f` or `ggml_cuda_mul_mat*_cublas` (`ggml-cuda.cu`)

Heuristics for MMQ usage live in `bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11)` (`mmq.cu`), where `ne11` is the batch dimension (number of activation columns, i.e., tokens):

- Enabled for supported quant types (Q4_0/1, Q5_0/1, Q8_0, MXFP4, k-quants, IQ*)
- NVIDIA:
  - Turing and newer (INT8 Tensor Cores): MMQ is used.
  - Pre-Turing: MMQ runs on DP4A. GPUs without FP16 Tensor Cores (Pascal) stay on MMQ; GPUs with FP16 Tensor Cores (Volta) use MMQ only for small batches (`ne11 < 64`) and switch to FP16 Tensor Core float GEMMs for larger batches.
- AMD CDNA:
  - Prefer MMQ (uses MFMA int8). On CDNA3 (MI300), MMQ is forced because of rocBLAS/hipBLASLt issues noted in code comments.
  - On other CDNA parts, MMQ is preferred for small/medium batches and for certain quant types.
- AMD RDNA:
  - RDNA3/4: MMQ via DP4A-like `sudot4` only for small batches (`ne11 < 64`); otherwise hipBLAS/rocBLAS.
  - RDNA2 (and older): the `ne11 < 64` cutoff does not apply, so MMQ (via `sdot4`, or emulation on RDNA1) is used at all batch sizes.

## Architecture Tables (execution paths and precisions)

### NVIDIA GPUs

| Architecture | Weight GEMM (quant) | Attention | Float GEMM (unquant/fallback) | Notes |
|---|---|---|---|---|
| Pascal (SM 61) | MMQ via DP4A (INT8) | FA fallback vector/tile F16/F32; softmax F32 | cuBLAS F32/F16 (no TCs) | No Tensor Cores for INT8/F16; DP4A available. |
| Volta (SM 70) | MMQ via DP4A (INT8) for `ne11 < 64`; FP16 TC float GEMM for larger batches | FA FP16 MMA (Tensor Cores); softmax F32 | cuBLAS F16/F32 (TCs for FP16) | No INT8 TCs; FP16 MMA available. |
| Turing (SM 75) | MMQ via INT8 Tensor Cores (`mma.sync` s8); DP4A (MMVQ) for very small batches | FA FP16 MMA; softmax F32 | cuBLAS TC paths | INT8 TCs available; kernel choice switches by batch size. |
| Ampere (SM 80) | MMQ via INT8 Tensor Cores | FA FP16 MMA; softmax F32 | cuBLAS TC paths (TF32, FP16) | Features such as `cp.async` do not affect MMQ path selection. |
| Ada (SM 89) | MMQ via INT8 Tensor Cores | FA FP16 MMA; softmax F32 | cuBLAS TC paths | Similar to Ampere, with higher clocks/throughput. |

Precision summary (NVIDIA): weight GEMMs are primarily INT8 compute with INT32 partials and FP32 outputs; attention uses FP16 MMA on Volta and newer; softmax/norm/RoPE are FP32.

### AMD GPUs (HIP/ROCm)

| Architecture | Weight GEMM (quant) | Attention | Float GEMM (unquant/fallback) | Notes |
|---|---|---|---|---|
| CDNA1/2 (MI100/MI210) | MMQ via MFMA INT8 | FA FP16 MMA via rocWMMA if enabled; otherwise vector F16/F32; softmax F32 | hipBLAS/rocBLAS | MFMA INT8 is fast; FA WMMA requires `GGML_HIP_ROCWMMA_FATTN`. |
| CDNA3 (MI300) | MMQ via MFMA INT8 (preferred for stability/perf) | FA FP16 MMA via rocWMMA if enabled; else vector | hipBLAS/rocBLAS (code comments note rocBLAS/hipBLASLt issues); MMQ forced | Code prefers MMQ on CDNA3 regardless of batch size. |
| RDNA2 (RX 6000) | MMQ via DP4A-like `__builtin_amdgcn_sdot4` | No WMMA hardware; FA vector F16/F32; softmax F32 | hipBLAS/rocBLAS | DP4A-equivalent INT8 dot; no MFMA. |
| RDNA3/4 (RX 7000/RX 9000) | MMQ via DP4A-like `__builtin_amdgcn_sudot4` only when `ne11 < 64`; else BLAS | FA FP16 WMMA via rocWMMA if built with `GGML_HIP_ROCWMMA_FATTN`; else vector | hipBLAS/rocBLAS | WMMA usable via rocWMMA on RDNA3/4 when enabled; RDNA has no MFMA. |
| RDNA1 | MMQ functional, but DP4A is emulated (not native), so performance is limited | FA vector F16/F32; softmax F32 | hipBLAS/rocBLAS | Older arch; limited INT8 support. |

HIP/ROCm INT8 support: yes. It is implemented via AMD MFMA int8 on CDNA and via `sdot4`/`sudot4` dot products on RDNA2/3/4, with fallbacks/emulation where needed (`common.cuh: ggml_cuda_dp4a`).

## What each path computes

- MMQ (quant GEMM)
  - Quantizes activations into the `block_q8_1_mmq` layout (int8 values plus per-block scales/partial sums) in `quantize.cu`; the MMQ kernel then loads these tiles into shared memory
  - Loads quantized weight blocks (Q4/Q5/Q8/k/IQ variants) and computes INT8×INT8→INT32 partial dot products (`vecdotq.cuh`) using:
    - NVIDIA: DP4A or `mma.sync.*.s8.s8.s32`
    - AMD: `__builtin_amdgcn_mfma_i32_*i8` (CDNA) or `__builtin_amdgcn_sdot4`/`sudot4` (RDNA)
  - Applies scales and writes FP32 outputs
- FlashAttention (FA)
  - If FP16 MMA is available, uses tensor core tiles for Q·K^T and P·V; otherwise vector/tile kernels in F16/F32 (`fattn-*.cu`)
  - Softmax is computed inside the FA kernel rather than by `softmax.cu`
- Float GEMM (unquantized/fallback)
  - cuBLAS/hipBLAS calls (F16/BF16/F32) or custom MMF/MMVF for small shapes
- Other ops
  - Norms, RoPE, embeddings, masking: FP32

## Path selection details (high level)

- Top-level mul_mat: `ggml-cuda.cu`
  - Chooses between quantized kernels (MMVQ/MMQ), float kernels (MMF/MMVF), and BLAS based on tensor types, transposition, sizes, and per-arch feature checks
  - For quantized weights and FP32 activations beyond the MMVQ batch range, MMQ is used when `ggml_cuda_should_use_mmq(...)` returns true
- MMQ device heuristics: `mmq.cu: ggml_cuda_should_use_mmq`
  - NVIDIA: MMQ on Turing and newer; on pre-Turing GPUs MMQ uses DP4A, and Volta (which has FP16 Tensor Cores) switches to FP16 TC GEMMs for `ne11 >= 64`
  - AMD CDNA: prefer MMQ (MFMA); CDNA3 forces MMQ
  - AMD RDNA3/4: MMQ for small batches (`ne11 < 64`), BLAS otherwise; RDNA2 stays on MMQ
- Attention:
  - The FP16 MMA path is enabled when `FP16_MMA_AVAILABLE` is defined (NVIDIA Volta and newer, or HIP with rocWMMA on supported arch/flags)
  - Otherwise F16/F32 vector/tile fallbacks are used

## Approximate compute % by architecture (decode step)

These ranges assume common transformer configs with quantized weights and FP16/FP32 activations, single-token decode, and short-to-moderate context. Percentages shift with sequence length and model size.

- NVIDIA (Pascal → Ada)
  - INT8 weight GEMMs (MMVQ/MMQ): 85–95%
  - FlashAttention (Q·K^T, P·V): 3–10% (FP16 MMA on Volta and newer; vector on Pascal)
  - Softmax/mask/norm/RoPE/embedding: 1–5% total (FP32)
  - Quantize/dequantize overhead: ~1–3%
- AMD CDNA (MI100/MI210/MI300)
  - INT8 weight GEMMs (MMVQ at decode; MMQ via MFMA for larger batches): 85–95%
  - FlashAttention: 3–10% (FP16 WMMA with rocWMMA; else vector)
  - Other FP32 ops: 1–5%
  - Quantize overhead: ~1–3%
- AMD RDNA2/3/4
  - For small batches (`ne11 < 64`), INT8 quantized GEMMs (`sdot4`/`sudot4`) dominate similarly (80–90%)
  - For larger batches on RDNA3/4, BLAS float GEMMs dominate; FA (vector, or WMMA on RDNA3/4 with rocWMMA) adds 5–12%
  - Other FP32 ops: 1–5%

Caveats: exact shares depend on batch size, head dims, FFN width, sequence length, cache layout, and kernel tiling.
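The MMQ-versus-BLAS choice described in the path-selection sections above can be summarized as a small decision function. This is a simplified, hypothetical model written from this document's description, not ggml's actual code: the `Arch` labels are illustrative groupings, 64 is the `ne11` cutoff quoted above, CDNA1/2 (whose rule also depends on quant type) and RDNA2 are left out, and the MMVQ path and any build-time overrides are ignored. It assumes the weight type is one MMQ supports.

```python
from enum import Enum, auto

MMQ_BATCH_CUTOFF = 64  # the "ne11 < 64" threshold described in the text


class Arch(Enum):
    PASCAL = auto()       # DP4A, no Tensor Cores
    VOLTA = auto()        # DP4A for INT8, FP16 Tensor Cores
    TURING_PLUS = auto()  # Turing/Ampere/Ada: INT8 Tensor Cores
    CDNA3 = auto()        # MFMA INT8; MMQ forced
    RDNA3_4 = auto()      # sudot4 INT8 dot; FP16 WMMA-capable BLAS


def should_use_mmq(arch: Arch, ne11: int) -> bool:
    """Simplified model of the MMQ-vs-float-GEMM choice for a supported quant type."""
    if arch in (Arch.TURING_PLUS, Arch.CDNA3, Arch.PASCAL):
        return True
    if arch in (Arch.VOLTA, Arch.RDNA3_4):
        # FP16 matrix hardware makes BLAS competitive once the batch is large.
        return ne11 < MMQ_BATCH_CUTOFF
    raise ValueError(f"unmodeled architecture: {arch}")


def weight_gemm_path(arch: Arch, ne11: int) -> str:
    """Human-readable label for the path the model above would pick."""
    if not should_use_mmq(arch, ne11):
        return "float GEMM (cuBLAS/hipBLAS)"
    engine = {
        Arch.PASCAL: "DP4A",
        Arch.VOLTA: "DP4A",
        Arch.TURING_PLUS: "INT8 mma.sync",
        Arch.CDNA3: "MFMA int8",
        Arch.RDNA3_4: "sudot4",
    }[arch]
    return f"MMQ ({engine})"


if __name__ == "__main__":
    for arch in Arch:
        small = weight_gemm_path(arch, 16)
        large = weight_gemm_path(arch, 512)
        print(f"{arch.name:12s} ne11=16: {small:30s} ne11=512: {large}")
```

In this model a 512-token prompt batch on RDNA3/4 or Volta lands on the BLAS path, while Turing and newer, CDNA3, and Pascal stay on MMQ.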
## Build/flag notes (HIP)

- INT8 MFMA in MMQ requires CDNA and `!GGML_HIP_NO_MMQ_MFMA` (enabled by default)
- The FA tensor core path on HIP requires `GGML_HIP_ROCWMMA_FATTN` (and `rocwmma` installed); it is enabled for CDNA and RDNA3/4 (RDNA4 optionally gated via `GGML_HIP_ROCWMMA_FATTN_GFX12`)
- For RDNA2/3/4, INT8 dot products use `__builtin_amdgcn_sdot4`/`sudot4`; RDNA1 uses a slower fallback/emulation

## Practical takeaways

- For quantized models on NVIDIA Turing/Ampere/Ada and AMD CDNA, batched weight GEMMs (e.g., prompt processing) run as INT8 on tensor cores/MFMA, so achievable TOPS are the key limiter there; single-token decode goes through MMVQ and is limited mainly by memory bandwidth. FA runs in FP16 on tensor cores where possible.
- On NVIDIA Pascal/Volta and AMD RDNA (no MFMA), MMQ uses DP4A-class dot products where available, or emulation; throughput is lower, and BLAS float GEMMs are chosen for larger batches on Volta and RDNA3/4.
- Softmax, norms, rotary embeddings, and embeddings remain FP32 across architectures and contribute modestly to total compute.

## RDNA3 long-context slowdown vs NVIDIA

Observation: on RDNA3, short-context prompt processing (e.g., pp512 in llama-bench notation) may be ~2× slower than on a comparable NVIDIA GPU, but at long context (e.g., pp512+d32768, i.e., 512-token prompt processing at a depth of 32768 tokens) the gap can widen to ~5×. The primary reason is that the attention portion grows with context length, and RDNA3's default attention path is significantly slower unless rocWMMA is enabled and used by FlashAttention (FA).

### Why it grows with context

- Per token, the FFNs and other weight GEMMs do roughly constant work, while attention work increases with the amount of K/V attended to (more memory traffic and dot products along K; FA reduces traffic but still scales with sequence length).
- On NVIDIA Turing/Ampere/Ada, FA uses FP16 tensor cores, keeping the attention portion fast as it grows. On RDNA3, if FA falls back to vector/tile kernels (no WMMA), attention throughput is much lower, so attention dominates total time at long context and the gap widens.
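The argument above is essentially Amdahl's law with a context-dependent attention term. A minimal sketch, using made-up cost parameters chosen only for illustration (not measurements): each GPU's time per token is a constant weight-GEMM cost plus an attention cost proportional to context length. The RDNA3 weight-GEMM cost is set to 2× NVIDIA's, and its per-position attention cost to a hypothetical 8× for the non-WMMA FA path.

```python
def time_per_token(ffn_cost: float, attn_cost_per_pos: float, n_ctx: int) -> float:
    """Constant weight-GEMM cost plus an attention cost that grows linearly with context."""
    return ffn_cost + attn_cost_per_pos * n_ctx


# Hypothetical, normalized costs (NVIDIA weight-GEMM time = 1.0).
# NVIDIA attention is set so that it equals the GEMM time at 32768 tokens.
NV_FFN, NV_ATTN = 1.0, 1.0 / 32768
RDNA3_FFN, RDNA3_ATTN = 2.0, 8.0 / 32768  # 2x slower GEMMs, 8x slower attention (assumed)


def gap(n_ctx: int) -> float:
    """RDNA3/NVIDIA time ratio at a given context length."""
    return time_per_token(RDNA3_FFN, RDNA3_ATTN, n_ctx) / time_per_token(NV_FFN, NV_ATTN, n_ctx)


if __name__ == "__main__":
    for n_ctx in (512, 4096, 32768, 262144):
        print(f"n_ctx={n_ctx:6d}  RDNA3/NVIDIA time ratio = {gap(n_ctx):.2f}")
```

With these parameters the ratio is about 2.09 at 512 tokens, 2.67 at 4096, 5.00 at 32768, and 7.33 at 262144, approaching the assumed 8× attention slowdown as context grows. The parameters were picked to reproduce the shape of the 2×→5× observation; they are not evidence for it.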
### RDNA3-specific factors

- Weight GEMMs are not the cause, because their cost stays roughly constant as context grows. For decode, RDNA3 runs them as INT8 via `sudot4` (MMVQ/MMQ); for pp512 the batch (`ne11 = 512`) exceeds the `ne11 < 64` MMQ cutoff, so they go through hipBLAS instead. Either way, the widening gap points to attention.
- FA path: WMMA on RDNA3 is not used unless `GGML_HIP_ROCWMMA_FATTN` is enabled at build time. Without it, RDNA3 uses the non-WMMA FA kernels (F16/F32 vector), which are much slower at long K.
- BLAS and shape heuristics: for larger batches/tiles, the RDNA backend may fall back to BLAS or less optimal kernels, whereas NVIDIA has mature TC kernels for FA and GEMM across many shapes.
- Memory/cache sensitivity: long-context FA touches more of the K/V cache. Differences in cache hierarchy and compiler scheduling can hurt the fallback FA kernels on RDNA3 more than the TC path on NVIDIA.

### What to do: optimizations that materially help RDNA3

1) Enable rocWMMA FlashAttention on RDNA3
   - Build with `-DGGML_HIP_ROCWMMA_FATTN=ON` and make sure the `rocwmma` headers are available (CMake checks for `rocwmma/rocwmma.hpp`).
   - This enables `FP16_MMA_AVAILABLE` for HIP on RDNA3 in `common.cuh`, selecting the WMMA FA path in `fattn.cu`.
   - Attention-heavy segments can be expected to speed up by roughly 2–3×, which directly addresses the long-context gap; the exact gain depends on shapes and ROCm version.
2) Verify FA is actually taken and dimension-matched
   - FA MMA kernels are specialized for common head dims: 64, 80, 96, 112, 128, 256, 576. Make sure your model's head dimension is one of these so the MMA kernel is selected; the set handled by the WMMA kernel can differ, so check the dispatch in `fattn.cu` for your build.
   - Check that GQA settings trigger the optimized branches (`fattn.cu` checks divisibility of `gqa_ratio`).
3) Keep MMQ enabled for quantized weights (its cost does not change with context, but it matters for overall performance)
   - On RDNA3 the quantized weight kernels rely on `sudot4` and dominate compute at short context. Confirm that `ggml_cuda_should_use_mmq(...)` returns true for your quant types and batch sizes; on RDNA3 it returns false for `ne11 >= 64` (e.g., pp512), which then uses hipBLAS.
4) Toolchain and flags
   - Use ROCm ≥ 6.1 (required by `ggml-hip/CMakeLists.txt`).
   - If building for RDNA4, add `-DGGML_HIP_ROCWMMA_FATTN_GFX12=ON` so WMMA FA is allowed on GFX12.
   - Ensure `CMAKE_PREFIX_PATH` includes the ROCm CMake directories and the rocWMMA installation paths.
5) Secondary tunings (smaller gains)
   - Ensure the softmax kernel launches with enough rows per block for your `ncols`; it is FP32 and can be memory-bound at large `nheads`/`ncols`, but it is rarely the top bottleneck (and is bypassed when FA is used).
   - Keep the K/V cache layout contiguous and avoid host-device syncs in the decode loop.

### Why NVIDIA scales better at long context

- NVIDIA's FA uses FP16 tensor cores broadly, with kernels tuned across many shapes and high on-chip bandwidth (ldmatrix→mma pipelines). As attention comes to dominate with context, these kernels retain a high fraction of theoretical throughput.
- RDNA3 without WMMA falls back to scalar/vector code, which amplifies the gap as the attention workload grows. Enabling rocWMMA narrows this gap substantially.

### Summary

- The widening 2×→5× gap at long contexts is primarily a software-path issue on RDNA3: FA does not use WMMA by default. Enabling `GGML_HIP_ROCWMMA_FATTN` and using a supported head dim can yield roughly 2–3× FA speedups and materially reduces the long-context penalty.
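The summary's claim can be sanity-checked with back-of-envelope arithmetic. Assume (hypothetically) that attention is a fraction `a` of NVIDIA's time at long context and that the short-context gap `g` reflects only the weight-GEMM gap. The observed long-context gap `G` then implies an attention slowdown `k = (G - g(1 - a)) / a`, and an FA speedup `s` from rocWMMA predicts a new gap `g(1 - a) + (k / s) a`. The values of `a` and `s` below are assumptions, not measurements, and the function names are illustrative.

```python
def implied_attention_slowdown(gap_short: float, gap_long: float, attn_share_nv: float) -> float:
    """Solve gap_long = gap_short * (1 - a) + k * a for k, the RDNA3/NVIDIA attention time ratio.

    Assumes attention is negligible at short context, so gap_short is the weight-GEMM gap.
    """
    a = attn_share_nv
    return (gap_long - gap_short * (1.0 - a)) / a


def projected_gap(gap_short: float, k: float, attn_share_nv: float, fa_speedup: float) -> float:
    """Long-context gap after the RDNA3 attention path becomes fa_speedup times faster."""
    a = attn_share_nv
    return gap_short * (1.0 - a) + (k / fa_speedup) * a


if __name__ == "__main__":
    a = 0.5  # assumed NVIDIA attention share at depth 32768 (hypothetical)
    k = implied_attention_slowdown(2.0, 5.0, a)
    print(f"implied attention slowdown: {k:.1f}x")
    for s in (2.0, 2.5, 3.0):
        print(f"FA speedup {s:.1f}x -> long-context gap {projected_gap(2.0, k, a, s):.2f}x")
```

With `a = 0.5`, the observed 2× and 5× gaps imply an 8× attention slowdown, and a 2–3× FA speedup would bring the long-context gap down to roughly 2.3–3×. Algebraically the projection equals `G/s + g(1 - a)(1 - 1/s)`, so a larger assumed attention share yields a smaller projected gap.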