Skip to content

Instantly share code, notes, and snippets.

@chsasank
Last active February 6, 2024 13:17
Show Gist options
  • Select an option

  • Save chsasank/6061343f6589c8e3269e4d504bb64d2a to your computer and use it in GitHub Desktop.

Select an option

Save chsasank/6061343f6589c8e3269e4d504bb64d2a to your computer and use it in GitHub Desktop.

Revisions

  1. chsasank revised this gist Feb 6, 2024. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion build.sh
    Original file line number Diff line number Diff line change
    @@ -1 +0,0 @@
    icpx -fsycl -std=c++17 joint-matrix.cpp -o joint-matrix -lsycl -lOpenCL
  2. chsasank revised this gist Feb 6, 2024. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions build.sh
    Original file line number Diff line number Diff line change
    @@ -0,0 +1 @@
    icpx -fsycl -std=c++17 joint-matrix.cpp -o joint-matrix -lsycl -lOpenCL
  3. chsasank created this gist Feb 6, 2024.
    194 changes: 194 additions & 0 deletions joint_matrix.cpp
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,194 @@
    //==============================================================
    // Copyright © 2022 Intel Corporation
    //
    // SPDX-License-Identifier: MIT
    // =============================================================

    #include <iostream>
    #include <sycl/sycl.hpp>
    #include <chrono>

    // using joint_matrix = sycl::ext::oneapi::experimental::matrix;
    using use = sycl::ext::oneapi::experimental::matrix::use;
    using layout = sycl::ext::oneapi::experimental::matrix::layout;
    using bfloat16 = sycl::ext::oneapi::bfloat16;

    #define SG_SZ 8

    #define TM 8
    #define TN SG_SZ
    #define TK 16

    #define BF16_EPSILON 0.00781250

    template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
    private:
    T *mat;

    public:
    T *get_data() { return mat; }
    void set_data(T *data) { mat = data; }
    big_matrix(T *data) : mat(data) {}
    };

    template <typename T1, typename T2, size_t M, size_t N, size_t K>
    void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
    big_matrix<T2, K / 2, N * 2> &B) {
    // kernel begin
    size_t NDRangeM = M / TM;
    size_t NDRangeN = N / TN;
    sycl::buffer<bfloat16, 2> bufA(A.get_data(), sycl::range<2>(M, K));
    sycl::buffer<bfloat16, 2> bufB(B.get_data(), sycl::range<2>(K, N));
    sycl::buffer<float, 2> bufC((float *)C.get_data(), sycl::range<2>(M, N));

    sycl::queue q;
    q.submit([&](sycl::handler &cgh) {
    sycl::accessor accC(bufC, cgh, sycl::read_write, sycl::no_init);
    sycl::accessor accA(bufA, cgh, sycl::read_only);
    sycl::accessor accB(bufB, cgh, sycl::read_only);

    cgh.parallel_for(
    sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
    [=](sycl::nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

    {
    // The joint matrix API has to be accessed by all the workitems in a
    // subgroup these functions will be called once by the subgroup no
    // code divergence between the workitems
    const auto global_idx = spmd_item.get_global_id(0);
    const auto global_idy = spmd_item.get_global_id(1);
    const auto sg_startx = global_idx - spmd_item.get_local_id(0);
    const auto sg_starty = global_idy - spmd_item.get_local_id(1);

    sycl::sub_group sg = spmd_item.get_sub_group();
    sycl::ext::oneapi::experimental::matrix::joint_matrix<
    sycl::sub_group, bfloat16, use::a, TM, TK, layout::row_major>
    sub_a;
    // For B, we assume B has been already VNNIed.
    sycl::ext::oneapi::experimental::matrix::joint_matrix<
    sycl::sub_group, bfloat16, use::b, TK, TN,
    sycl::ext::intel::experimental::matrix::layout::packed>
    sub_b;
    sycl::ext::oneapi::experimental::matrix::joint_matrix<
    sycl::sub_group, float, use::accumulator, TM, TN>
    sub_c;

    joint_matrix_load(sg, sub_c,
    accC.get_pointer() + (sg_startx * TM) * N +
    sg_starty / SG_SZ * TN,
    N, layout::row_major);
    for (int k = 0; k < K / TK; k += 1) { //
    joint_matrix_load(
    sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
    K);
    joint_matrix_load(sg, sub_b,
    accB.get_pointer() + (k * TK / 2) * (N * 2) +
    sg_starty / SG_SZ * TN * 2,
    N * 2);
    sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
    }
    joint_matrix_store(sg, sub_c,
    accC.get_pointer() + (sg_startx * TM) * N +
    sg_starty / SG_SZ * TN,
    N, layout::row_major);
    }); // parallel for
    }).wait();
    // kernel end
    }

    static constexpr size_t MATRIX_M = TM * 128 * 1;
    static constexpr size_t MATRIX_N = TN * 128 * 1;
    static constexpr size_t MATRIX_K = TK * 64 * 1;
    bfloat16 A[MATRIX_M][MATRIX_K];
    bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
    unsigned short Aref[MATRIX_M][MATRIX_K];
    unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
    float C[MATRIX_M][MATRIX_N];
    float D[MATRIX_M][MATRIX_N];

    float make_fp32(short x) {
    unsigned int y = x;
    y = y << 16;
    float *res = reinterpret_cast<float *>(&y);
    return *res;
    }

    unsigned short make_bf16(float x) {
    int *res = reinterpret_cast<int *>(&x);
    *res = *res >> 16;
    return (unsigned short)*res;
    }

    void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
    int K) {
    for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
    for (int k = 0; k < K; k++) {
    short *va = (short *)(A_mem + m * K + k);
    short *vb = (short *)(B_mem + k * N + n);
    float acc = *((float *)(C_mem + m * N + n));
    for (int i = 0; i < 2; i++) {
    acc += (make_fp32(va[i]) * make_fp32(vb[i]));
    }
    *((float *)(C_mem + m * N + n)) = acc;
    }
    }
    }

    int main() {
    for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
    // bfloat16 is created using unsigned short since conversion from float to
    // bfloat16 is not supported on the host side yet
    A[i][j] = bfloat16(1.0f * (i + j));
    Aref[i][j] = make_bf16(1.0f * (i + j));
    }
    }
    for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
    B[i][j] = bfloat16(2.0f * i + 3.0f * j);
    Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
    }
    }
    for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
    C[i][j] = 1.0;
    D[i][j] = 1.0;
    }
    }

    big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
    big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
    big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
    big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);

    int num_trails = 20;
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < num_trails; i++){
    matrix_multiply(MC, MA, MB);
    }
    auto end = std::chrono::steady_clock::now();
    auto total_time = std::chrono::duration<double>(end - start).count();
    std::cout << "time for gpu gemm " << MATRIX_M << " " << MATRIX_N << " " << MATRIX_K <<
    " for " << num_trails << " trails: " << total_time << std::endl;
    auto avg = total_time / num_trails;
    auto op_count = double(MATRIX_M) * double(MATRIX_N) * double(MATRIX_K) * 2;
    auto flops = op_count / avg;
    std::cout << "estimated flops: " << flops << std::endl;

    matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
    MATRIX_N, MATRIX_K / 2);

    bool res = true;
    for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
    C[i][j] = C[i][j] / num_trails;
    auto diff_r = 2 * fabs(C[i][j] - D[i][j]) / (fabs(C[i][j]) + fabs(D[i][j]));
    if (diff_r > 1e-2){
    res = false;
    }
    }
    }
    std::cout << (res ? "passed" : "failed") << std::endl;
    return !res;
    }