//==============================================================
// Copyright © 2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
// =============================================================
#include <sycl/sycl.hpp>
#include <iostream>
#include <chrono>
#include <cmath>
#include <cstdint>

// using joint_matrix = sycl::ext::oneapi::experimental::matrix;
using use = sycl::ext::oneapi::experimental::matrix::use;
using layout = sycl::ext::oneapi::experimental::matrix::layout;
using bfloat16 = sycl::ext::oneapi::bfloat16;

#define SG_SZ 8

#define TM 8
#define TN SG_SZ
#define TK 16

#define BF16_EPSILON 0.00781250

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};

template <typename T1, typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                     big_matrix<T2, K / 2, N * 2> &B) {
  // kernel begin
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
  // B is stored VNNI-packed as (K / 2) x (N * 2); the element count matches K x N.
  sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
  sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));

  sycl::queue q;
  q.submit([&](sycl::handler &cgh) {
    // C is read (joint_matrix_load) before it is written, so it must not be
    // declared with sycl::no_init.
    sycl::accessor accC(bufC, cgh, sycl::read_write);
    sycl::accessor accA(bufA, cgh, sycl::read_only);
    sycl::accessor accB(bufB, cgh, sycl::read_only);

    cgh.parallel_for(
        sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
        [=](sycl::nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
          // The joint matrix API has to be accessed by all work-items in a
          // subgroup; these functions are called once per subgroup, so there
          // is no code divergence between the work-items.
          const auto global_idx = spmd_item.get_global_id(0);
          const auto global_idy = spmd_item.get_global_id(1);
          const auto sg_startx = global_idx - spmd_item.get_local_id(0);
          const auto sg_starty = global_idy - spmd_item.get_local_id(1);

          sycl::sub_group sg = spmd_item.get_sub_group();
          sycl::ext::oneapi::experimental::matrix::joint_matrix<
              sycl::sub_group, bfloat16, use::a, TM, TK, layout::row_major>
              sub_a;
          // For B, we assume B has already been VNNIed.
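          // Packed ("VNNI") layout, as assumed here for B: adjacent pairs of
          // bf16 rows of the logical K x N matrix are interleaved element by
          // element, so logical element B(k, n) lives at B[k / 2][2 * n + k % 2]
          // in a (K / 2) x (N * 2) array. This is why the B load below uses a
          // row stride of N * 2 and a row offset of k * TK / 2.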
          sycl::ext::oneapi::experimental::matrix::joint_matrix<
              sycl::sub_group, bfloat16, use::b, TK, TN,
              sycl::ext::intel::experimental::matrix::layout::packed>
              sub_b;
          sycl::ext::oneapi::experimental::matrix::joint_matrix<
              sycl::sub_group, float, use::accumulator, TM, TN>
              sub_c;

          // Load the accumulator tile of C owned by this subgroup.
          joint_matrix_load(sg, sub_c,
                            accC.get_pointer() + (sg_startx * TM) * N +
                                sg_starty / SG_SZ * TN,
                            N, layout::row_major);
          for (int k = 0; k < K / TK; k += 1) {
            joint_matrix_load(sg, sub_a,
                              accA.get_pointer() + (sg_startx * TM) * K +
                                  k * TK,
                              K);
            joint_matrix_load(sg, sub_b,
                              accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                  sg_starty / SG_SZ * TN * 2,
                              N * 2);
            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          }
          joint_matrix_store(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, layout::row_major);
        }); // parallel for
  }).wait();
  // kernel end
}

static constexpr size_t MATRIX_M = TM * 128 * 1;
static constexpr size_t MATRIX_N = TN * 128 * 1;
static constexpr size_t MATRIX_K = TK * 64 * 1;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
unsigned short Aref[MATRIX_M][MATRIX_K];
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

// Reinterpret a bf16 bit pattern as fp32 by shifting it into the upper 16 bits.
float make_fp32(short x) {
  unsigned int y = x;
  y = y << 16;
  float *res = reinterpret_cast<float *>(&y);
  return *res;
}

// Truncate an fp32 value to its upper 16 bits to get the bf16 bit pattern.
unsigned short make_bf16(float x) {
  int *res = reinterpret_cast<int *>(&x);
  *res = *res >> 16;
  return (unsigned short)*res;
}

// Reference GEMM on the VNNI-packed data: each "int" element holds a pair of
// bf16 values, so every k step accumulates two products.
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        short *va = (short *)(A_mem + m * K + k);
        short *vb = (short *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}

int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      // Aref keeps the bfloat16 bit pattern in an unsigned short because
      // float-to-bfloat16 conversion is not supported on the host side yet.
      A[i][j] = bfloat16(1.0f * (i + j));
      Aref[i][j] = make_bf16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16(2.0f * i + 3.0f * j);
      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }

  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);

  int num_trials = 20;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < num_trials; i++) {
    matrix_multiply(MC, MA, MB);
  }
  auto end = std::chrono::steady_clock::now();
  // Elapsed wall-clock time in seconds.
  auto total_time = std::chrono::duration<double>(end - start).count();
  std::cout << "time for gpu gemm " << MATRIX_M << " " << MATRIX_N << " "
            << MATRIX_K << " for " << num_trials
            << " trials: " << total_time << " s" << std::endl;
  auto avg = total_time / num_trials;
  auto op_count = double(MATRIX_M) * double(MATRIX_N) * double(MATRIX_K) * 2;
  auto flops = op_count / avg;
  std::cout << "estimated flops: " << flops << std::endl;

  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      // C accumulates A * B once per trial, so divide by the trial count
      // before comparing against the single-pass reference result in D.
      C[i][j] = C[i][j] / num_trials;
      auto diff_r =
          2 * fabs(C[i][j] - D[i][j]) / (fabs(C[i][j]) + fabs(D[i][j]));
      if (diff_r > 1e-2) {
        res = false;
      }
    }
  }
"passed" : "failed") << std::endl; return !res; }