// uses https://github.com/HazyResearch/ThunderKittens
#include "tk/src/kittens.cuh"
#include "tk/src/common/pyutils/torch_helpers.cuh"

#define NUM_WORKERS 2 // This kernel uses this many workers in parallel per block, to help issue instructions more quickly.
#define DIMENSION 64  // This kernel operates over 64-dimensional vectors.
#define DEBUG 0

using namespace kittens;

// This kernel only handles head dim == q_reg.cols (64) for simplicity.
// n must also be a multiple of q_reg.rows*NUM_WORKERS (32 here).
template<typename H>
__global__ void causal_attend_kernel(int n, const H* __restrict__ __q__, const H* __restrict__ __k__,
                                     const H* __restrict__ __v__, const float* __restrict__ __f__, H* __o__) {

    auto warpid      = kittens::warpid();
    auto block_start = blockIdx.x*(n*DIMENSION); // each block owns one (batch, head) slice
    const bf16 *_q = reinterpret_cast<const bf16*>(__q__) + block_start,
               *_k = reinterpret_cast<const bf16*>(__k__) + block_start,
               *_v = reinterpret_cast<const bf16*>(__v__) + block_start;
          bf16 *_o = reinterpret_cast<bf16*>(__o__) + block_start;

    extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
    shared_allocator al((int*)&__shm[0]);

    // K and V live in shared memory -- this is about all that will fit.
    st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
    st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();

    // Initialize all of the register tiles.
    rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg needs to be swapped into col_l
    rt_fl_1x1<> att_block;
    rt_bf_1x1<> att_block_mma;
    rt_fl_1x4<> o_reg;

    int qo_blocks = n / (q_reg.rows*NUM_WORKERS);
    int kv_blocks = n / (q_reg.rows*NUM_WORKERS);

    for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {

        // each warp loads its own 16x64 Q tile
        auto q_index = q_blk*NUM_WORKERS + warpid;
        load(q_reg, _q + q_index*q_reg.num_elements, q_reg.cols);

        // zero the O accumulator registers for this Q tile.
        zero(o_reg);

        // iterate over k, v for these q's that have been loaded.
        // under the causal mask only kv blocks up to the current q block can contribute.
        for(auto kv_idx = 0; kv_idx <= q_blk && kv_idx < kv_blocks; kv_idx++) {

            int kv_warp_index = kv_idx*NUM_WORKERS + warpid;
            if (kv_warp_index <= q_index) { // ensure causality
                // each warp loads its own chunk of k, v into shared memory
                load(v_smem[warpid], _v + kv_warp_index*q_reg.num_elements, q_reg.cols);
                load(k_smem[warpid], _k + kv_warp_index*q_reg.num_elements, q_reg.cols);
            }
            __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase

            // now each warp goes through all of the subtiles, loads them, and runs the inner attention computation.
            for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {

                int kv_subtile_index = kv_idx*NUM_WORKERS + subtile;
                if (kv_subtile_index > q_index) { // ensure causality
                    continue;
                }

                load(k_reg, k_smem[subtile]); // load k from shared into registers

                zero(att_block); // zero 16x16 attention tile
                mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T

                copy(att_block_mma, att_block); // convert to bf16 for mma_AB
                if (kv_subtile_index == q_index) { // only the diagonal tile needs masking
                    make_causal(att_block_mma, att_block_mma, kittens::base_types::constants<bf16>::zero());
                }

                load(v_reg, v_smem[subtile]); // load v from shared into registers.
                rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg

                mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
            }
            __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
        }

        store(_o + q_index*q_reg.num_elements, o_reg, q_reg.cols); // write out o.
        // compiler has an issue with register usage if d is made constexpr q_reg.rows :/
    }
}
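// --------------------------------------------------------------------------
// Hedged reference: a minimal libtorch sketch of what the kernel above is
// intended to compute, useful for correctness checks against the TK output.
// Assumption: the kernel accumulates O = tril(Q @ K^T) @ V per (batch, head),
// with no softmax and no 1/sqrt(d) scaling (the `f` tensor is currently
// unused). `causal_attend_reference` is an illustrative helper, not part of
// ThunderKittens or the original harness.
// --------------------------------------------------------------------------
static torch::Tensor causal_attend_reference(torch::Tensor q, torch::Tensor k, torch::Tensor v) {
    // q, k, v: [batch, head, n, 64] bf16 tensors; compute in fp32 for a stable reference.
    auto qf = q.to(torch::kFloat32);
    auto kf = k.to(torch::kFloat32);
    auto vf = v.to(torch::kFloat32);
    auto scores = torch::matmul(qf, kf.transpose(-2, -1)); // [b, h, n, n] = Q @ K^T
    scores = torch::tril(scores);                          // causal mask: zero out positions j > i
    return torch::matmul(scores, vf).to(q.scalar_type());  // [b, h, n, 64]
}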
void causal_attend(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor f, torch::Tensor o_small) {
    CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); CHECK_INPUT(f); CHECK_INPUT(o_small);

    auto batch = q.size(0);
    auto head  = q.size(1);
    auto n     = q.size(2);
    auto d     = q.size(3);
    auto dv    = v.size(3);

    bool k_same = true;
    for(auto i = 0; i < 4; i++) {
        k_same &= q.size(i) == k.size(i);
    }
    // This is just a restriction of what we're doing now...
    TORCH_CHECK(k_same, "Q and K should be the same size");
    TORCH_CHECK(d == DIMENSION && dv == DIMENSION, "This kernel only supports head dimension 64");
    TORCH_CHECK(q.scalar_type() == c10::ScalarType::BFloat16, "Q must be bf16");
    TORCH_CHECK(k.scalar_type() == c10::ScalarType::BFloat16, "K must be bf16");
    TORCH_CHECK(v.scalar_type() == c10::ScalarType::BFloat16, "V must be bf16");

    using H = __nv_bfloat16;
    using T = c10::BFloat16;
    const int workers = NUM_WORKERS;

    unsigned long mem_size = 2*workers*sizeof(st_bf_1x4<ducks::st_layout::swizzle>);
    TORCH_CHECK(n % (workers*kittens::TILE_DIM) == 0, "The sequence length n should be divisible by the number of workers times the tile dimension");

    auto threads = workers * kittens::WARP_THREADS;
    //printf("[causal_attend] Requesting %lu bytes of memory for %d workers (%d threads)\n", mem_size, workers, threads);
    CHECK_CUDA_ERROR(cudaFuncSetAttribute(
        causal_attend_kernel<H>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        mem_size));

    causal_attend_kernel<H><<<batch*head, threads, mem_size>>>(
        (int)n,
        reinterpret_cast<const H*>(q.data_ptr<T>()),
        reinterpret_cast<const H*>(k.data_ptr<T>()),
        reinterpret_cast<const H*>(v.data_ptr<T>()),
        f.data_ptr<float>(),
        reinterpret_cast<H*>(o_small.data_ptr<T>()));

    CHECK_CUDA_ERROR(cudaDeviceSynchronize());
}

//#include "harness.impl"
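// --------------------------------------------------------------------------
// Hedged sketch: one way to expose causal_attend() to Python when this file is
// built as a PyTorch C++ extension. This binding is an illustrative assumption
// (the original project appears to use its own harness, per the commented-out
// harness.impl include above); TORCH_EXTENSION_NAME is defined by
// torch.utils.cpp_extension at build time, so the block compiles away otherwise.
// --------------------------------------------------------------------------
#ifdef TORCH_EXTENSION_NAME
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("causal_attend", &causal_attend,
          "Causal tril(Q @ K^T) @ V over bf16 tensors of shape [batch, head, n, 64] (ThunderKittens kernel)");
}
#endif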