@vmiheer
Created December 19, 2024 16:56
CMakeLists.txt
    cmake_minimum_required(VERSION 3.15)
    project(dump_partitions)
    include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/CPM.cmake)
    CPMAddPackage("gh:fmtlib/fmt#11.0.2")
    CPMAddPackage("gh:p-ranav/argparse#v3.1")

    file(WRITE "${CMAKE_BINARY_DIR}/cleanupxformdata.cmake"
    [=[
    file(READ "${SOURCE}" TEXT)
    string(REGEX REPLACE "\nmodule {\n" "\n" TEXT "${TEXT}")
    string(REGEX REPLACE "\n[ ]+module attributes .*" "\n" TEXT "${TEXT}")
    file(WRITE "${TARGET}" "${TEXT}")
    ]=])

    option(USE_KOKKOS "Use Kokkos" ON)

    find_package(MLIR REQUIRED)
    find_package(LLVM REQUIRED)

    if(USE_KOKKOS)
    find_package(Kokkos REQUIRED)
    file(WRITE "${CMAKE_BINARY_DIR}/makeExtern.cmake"
    [=[
    file(READ "${SOURCE}" TEXT)
    string(REGEX REPLACE "void[*] sparse_mha.*{" "extern \"C\" \\0" TEXT "${TEXT}")
    string(REGEX REPLACE "void[*] pte_[^{]*{" "extern \"C\" \\0" TEXT "${TEXT}")
    file(WRITE "${TARGET}" "${TEXT}")
    ]=])
    endif()
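
The makeExtern.cmake script just written rewrites the lapis-translate output so the sparse_mha and pte_* kernel definitions get C linkage (wrapper.cpp declares them extern "C"). A rough Python equivalent of the rewrite, with hypothetical file names:

    import re

    # Sketch of the makeExtern.cmake rewrite (regexes copied from the script
    # above; "bspmm.cpp.temp" / "bspmm.cpp" are hypothetical example names).
    text = open("bspmm.cpp.temp").read()
    text = re.sub(r'void\* sparse_mha.*\{', r'extern "C" \g<0>', text)
    text = re.sub(r'void\* pte_[^{]*\{', r'extern "C" \g<0>', text)
    open("bspmm.cpp", "w").write(text)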

    add_library(mmio mmio.c)

    function(compile_mlir SRC_FILE)
    add_custom_command(
    OUTPUT
    ${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}.o
    COMMAND
    mlir-opt --sparsifier=\"enable-runtime-library=true\"
    ${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FILE}.mlir -o
    ${SRC_FILE}.llvm.mlir
    COMMAND
    mlir-translate -mlir-to-llvmir ${SRC_FILE}.llvm.mlir -o ${SRC_FILE}.llvm
    COMMAND
    llc --relocation-model=pic ${SRC_FILE}.llvm -o ${SRC_FILE}.s
    COMMAND
    as -g -o ${SRC_FILE}.o ${SRC_FILE}.s
    DEPENDS
    ${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FILE}.mlir
    )
    add_library(${SRC_FILE} OBJECT IMPORTED)
    set_target_properties(${SRC_FILE} PROPERTIES IMPORTED_OBJECTS
    ${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}.o)
    endfunction()
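
compile_mlir chains four tools: mlir-opt lowers the sparse kernel, mlir-translate emits LLVM IR, llc compiles it to assembly, and as produces the object imported above. A minimal Python driver replaying the same commands outside CMake (a sketch; assumes the tools are on PATH):

    import subprocess

    def compile_mlir(src):
        # Same stages as the CMake function above, for a base name like "bspmm".
        run = lambda *cmd: subprocess.run(cmd, check=True)
        run("mlir-opt", "--sparsifier=enable-runtime-library=true",
            f"{src}.mlir", "-o", f"{src}.llvm.mlir")
        run("mlir-translate", "-mlir-to-llvmir", f"{src}.llvm.mlir",
            "-o", f"{src}.llvm")
        run("llc", "--relocation-model=pic", f"{src}.llvm", "-o", f"{src}.s")
        run("as", "-g", "-o", f"{src}.o", f"{src}.s")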

    function(compile_mlir_to_kokkos SRC_FILE)
    add_custom_command(
    OUTPUT
    ${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}.cpp
    COMMAND
    lapis-opt --sparse-compiler-kokkos='pt-backend=mpi parallelization-strategy=dense-any-loop'
    ${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FILE}.mlir -o ${SRC_FILE}.scf.mlir
    COMMAND
    lapis-translate -sparse-mlir-to-kokkos ${SRC_FILE}.scf.mlir -o
    ${SRC_FILE}.cpp.temp
    COMMAND
    # cp ${SRC_FILE}.cpp.temp ${SRC_FILE}.cpp
    ${CMAKE_COMMAND} -DSOURCE=${SRC_FILE}.cpp.temp -DTARGET=${SRC_FILE}.cpp
    -P makeExtern.cmake
    DEPENDS
    ${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FILE}.mlir
    )
    add_library(${SRC_FILE} ${CMAKE_CURRENT_BINARY_DIR}/${SRC_FILE}.cpp)
    target_compile_definitions(${SRC_FILE} PUBLIC USE_KOKKOS=1)
    # target_compile_options(${SRC_FILE} PRIVATE -g -ggdb)
    target_link_libraries(${SRC_FILE} PRIVATE Kokkos::kokkos fmt::fmt)
    endfunction()

    set(MLIR_INPUT_NAME "bspmm")
    if (USE_KOKKOS)
    compile_mlir_to_kokkos(${MLIR_INPUT_NAME})
    else()
    compile_mlir(${MLIR_INPUT_NAME})
    endif()

    add_executable(dump_partitions.part_tensor wrapper.cpp)


    target_include_directories(dump_partitions.part_tensor
    PRIVATE
    ${MLIR_INCLUDE_DIRS}
    ${LLVM_INCLUDE_DIRS}
    )
    target_link_libraries(dump_partitions.part_tensor
    PRIVATE
    ${MLIR_INPUT_NAME}
    MLIRExecutionEngine
    MLIRSparseTensorUtils
    MLIRSparseTensorRuntime
    mlir_c_runner_utils
    mlir_runner_utils
    argparse
    mmio
    fmt::fmt
    )
    target_compile_options(dump_partitions.part_tensor PUBLIC -fno-rtti)
    set_target_properties(dump_partitions.part_tensor PROPERTIES
    CXX_STANDARD 23
    CXX_EXTENSIONS OFF
    )
CMakePresets.json
    {
    "version": 6,
    "cmakeMinimumRequired": {
    "major": 3,
    "minor": 23,
    "patch": 0
    },
    "configurePresets": [
    {
    "name": "default-kokkos",
    "displayName": "default",
    "description": "Default local preset",
    "generator": "Ninja",
    "binaryDir": "${sourceDir}/localbuild",
    "cacheVariables": {
    "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
    "CMAKE_BUILD_TYPE": "Release",
    "USE_KOKKOS": "ON"
    }
    },
    {
    "name": "default-llvm",
    "displayName": "default",
    "description": "Default preset",
    "generator": "Ninja",
    "binaryDir": "${sourceDir}/build",
    "cacheVariables": {
    "CMAKE_C_COMPILER": "/usr/bin/gcc-13",
    "CMAKE_CXX_COMPILER": "/usr/bin/g++-13",
    "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
    "CMAKE_BUILD_TYPE": "Debug",
    "USE_KOKKOS": "OFF"
    }
    }
    ],
    "buildPresets": [
    {
    "name": "default-llvm",
    "displayName": "Default",
    "description": "Default preset",
    "configurePreset": "default-llvm"
    },
    {
    "name": "default",
    "displayName": "Default",
    "description": "Default preset",
    "configurePreset": "default-kokkos"
    },
    {
    "name": "notchpeak",
    "displayName": "notchpeak",
    "description": "Default preset for notchpeak",
    "configurePreset": "default-kokkos"
    }
    ],
    "workflowPresets": [
    {
    "name": "default-kokkos",
    "description": "default",
    "steps": [
    {
    "type": "configure",
    "name": "default-kokkos"
    },
    {
    "type": "build",
    "name": "default"
    }
    ]
    },
    {
    "name": "default-llvm",
    "description": "default",
    "steps": [
    {
    "type": "configure",
    "name": "default-llvm"
    },
    {
    "type": "build",
    "name": "default-llvm"
    }
    ]
    },
    {
    "name": "notchpeak",
    "description": "default workflow for notchpeak",
    "steps": [
    {
    "type": "configure",
    "name": "default-kokkos"
    },
    {
    "type": "build",
    "name": "notchpeak"
    }
    ]
    }
    ]
    }
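
With presets schema version 6 (CMake 3.25+), the workflow presets above run configure and build in one step, for example: cmake --workflow --preset default-kokkos.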
a.py
    #!/usr/bin/env python3

    import numpy as np
    from pathlib import Path
    import torch
    import dgl.sparse as dglsp
    from copy import deepcopy
    import sys


    def is_interactive():
    import __main__ as main

    return not hasattr(main, "__file__")


    def tensor_from_file(file_name: Path, dtype=np.float32, read_size=-1):
    if read_size > 0:
    a = open(file_name, "rb").read(read_size)
    else:
    a = open(file_name, "rb").read()
    n = np.frombuffer(a, dtype=dtype)
    return torch.from_numpy(deepcopy(n))
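
A quick round-trip check for tensor_from_file; the file name is a hypothetical example:

    # e.g. write 12 float32 values and read them back:
    demo = np.arange(12, dtype=np.float32)
    demo.tofile("demo.bin")
    assert torch.equal(tensor_from_file(Path("demo.bin")), torch.from_numpy(demo))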


    if is_interactive():
    sys.argv = [
    "a.py",
    "banded_11_r2.coordinates.bin",
    "11",
    "4",
    "2",
    ]
    # sys.argv = [
    # "a.py",
    # "banded_4_r2.coordinates.bin",
    # "4",
    # "1",
    # "1",
    # "/scratch/general/vast/u1290058/LAPIS_Workspace/snl-utah/sandboxes/mseyden/graph-attention/dgl/arxiv.coordinates.bin",
    # "169344",
    # "1",
    # "1",
    # ]

    if len(sys.argv) != 5:
    print(f"Usage: python3 {sys.argv[0]} <dataset_name> <N> <Dh> <Nh>")
    exit(1)

    N, Dh, Nh = map(int, sys.argv[2:])

    A = tensor_from_file(Path(sys.argv[1]), dtype=np.int64)
    nnzCount = A.shape[0] // 2
    info_file_name = sys.argv[1].split(".")[0] + ".coordinates.info"
    Nv, _, Nnz = map(int, open(info_file_name).read().split())
    print(f"Nv: {Nv}, N: {N}, Nnz: {Nnz}, nnzCount: {nnzCount}")
    print(f"N: {N}, Dh: {Dh}, Nh: {Nh}")
    # assert N == Nv
    edgeFeatPath = sys.argv[1].split(".")[0] + ".edge.data.bin"
    edgeData = tensor_from_file(edgeFeatPath, read_size=(Nnz * Nh * 4)).reshape(Nnz, Nh)
    print("EdgeData[0]: ", edgeData[0])
    A = dglsp.spmatrix(
    A.reshape(Nnz, 2).transpose(1, 0),
    # val=torch.tensor([1.0] * Nnz),
    val=edgeData,
    shape=(N, N),
    )
    featPath = sys.argv[1].split(".")[0] + ".vert.data.bin"
    # Q = tensor_from_file(featPath, read_size=(N * Dh * Nh * 4)).reshape(N, Dh, Nh)
    # K = tensor_from_file(featPath, read_size=(N * Dh * Nh * 4)).reshape(N, Dh, Nh)
    V = tensor_from_file(featPath, read_size=(N * Dh * Nh * 4)).reshape(N, Dh, Nh)

    spmm_out = dglsp.bspmm(A, V)
    print(spmm_out.shape)
    outPath = sys.argv[1].split(".")[0].split("/")[-1] + ".res"
    out = tensor_from_file(outPath, read_size=(N * Dh * Nh * 4)).reshape(N, Dh, Nh)

    # sddmm_out = dglsp.bsddmm(A, Q, K.transpose(1, 0)).softmax(dim=1)
    # spmm_out = dglsp.bspmm(sddmm_out, V)
    ## sddmm_out = dglsp.bsddmm(A, Q, K.transpose(1, 0)).to_dense().softmax(dim=1)
    ## spmm_out = torch.mm(torch.squeeze(sddmm_out), torch.squeeze(V, -1))
    ## # print(spmm_out)
    ## sddmm_out = dglsp.bsddmm(A, Q, K.transpose(1, 0)).softmax(dim=1)
    ## # print(sddmm_out)
    ## spmm_out = dglsp.bspmm(sddmm_out, V)
    ## print(sddmm_out)
    ## exit(1)
    ## spmm_out = torch.mm(sddmm_out, V)

    if not (torch.allclose(out, spmm_out)):
    print(spmm_out - out)
    print(spmm_out)
    print(out)
    exit(1)
    exit(0)
bspmm.mlir
    #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
    #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
    #csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
    (d0 : dense, d1 : compressed, d2 : dense) }>
    #dense = #sparse_tensor.encoding<{ map = (d0, d1) ->
    (d0 : dense, d1 : dense) }>
    #densev = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
    (d0 : dense, d1 : dense, d2 : dense) }>
    #csr = #sparse_tensor.encoding<{ map = (d0, d1) ->
    (d0 : dense, d1 : compressed) }>
    #partCsr = #part_tensor.encoding<{
    partConst = 1,
    sparseAttributes = #csr
    }>
    #partDensev = #part_tensor.encoding<{
    partConst = 1,
    sparseAttributes = #densev
    }>
    #bsddmm_map = {
    indexing_maps = [
    affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)>, // q (in)
    affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // k (in)
    affine_map<(n1, n2, dh, nh) -> (n1, n2)>, // A (in)
    affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)> // attn (out)
    ],
    iterator_types = ["parallel", "parallel", "reduction", "parallel"],
    doc = "attn(n1, n2, nh) = q(n1, dh, nh) * k(n2, dh, nh)"
    }
    #bspmm_map = {
    indexing_maps = [
    affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)>, // attn (in)
    affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // v (in)
    affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)> // out (out)
    ],
    iterator_types = ["parallel", "parallel", "reduction", "parallel"],
    doc = "out(n1, dh, nh) = attn(n1, n2, nh) * v(n2, dh, nh)"
    }
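
In dense terms the #bspmm_map contraction is out(n1, dh, nh) = sum over n2 of attn(n1, n2, nh) * v(n2, dh, nh); a NumPy sketch of the same reduction, for intuition only:

    import numpy as np

    # Dense reference for #bspmm_map:
    # out[n1, dh, nh] = sum_n2 attn[n1, n2, nh] * v[n2, dh, nh]
    n1 = n2 = 4; dh = 3; nh = 2
    attn = np.random.rand(n1, n2, nh).astype(np.float32)
    v = np.random.rand(n2, dh, nh).astype(np.float32)
    out = np.einsum("abn,bdn->adn", attn, v)
    assert out.shape == (n1, dh, nh)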

    module {
    func.func @pte_local_bspmm(%A: tensor<?x?x?xf32, #csrv>,
    %B: tensor<?x?x?xf32, #densev>) -> tensor<?x?x?xf32, #densev>
    {
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c2_index = arith.constant 2 : index
    %c3_index = arith.constant 3 : index
    %c4_index = arith.constant 4 : index
    %c5_index = arith.constant 5 : index
    %c6_index = arith.constant 6 : index
    %c9_index = arith.constant 9 : index

    %c0_f32 = arith.constant 0.0 : f32

    %N1 = tensor.dim %A, %c0_index : tensor<?x?x?xf32, #csrv>
    %N2 = tensor.dim %A, %c1_index : tensor<?x?x?xf32, #csrv>
    %nh = tensor.dim %A, %c2_index : tensor<?x?x?xf32, #csrv>
    %dh = tensor.dim %B, %c1_index : tensor<?x?x?xf32, #densev>
    %spmm_in0 = tensor.empty (%N1, %dh, %nh) : tensor<?x?x?xf32, #densev>
    %spmm_in1 = linalg.fill ins(%c0_f32: f32)
    outs(%spmm_in0 : tensor<?x?x?xf32, #densev>)
    -> tensor<?x?x?xf32, #densev>
    %attn4 = linalg.generic #bspmm_map
    ins(%A, %B: tensor<?x?x?xf32, #csrv>, tensor<?x?x?xf32, #densev>)
    outs(%spmm_in1: tensor<?x?x?xf32, #densev>) {
    ^bb0(%q: f32, %k: f32, %attn: f32): // no predecessors
    %0 = arith.mulf %q, %k : f32
    %1 = arith.addf %0, %attn: f32
    linalg.yield %1 : f32
    } -> tensor<?x?x?xf32, #densev>
    // %attn4 = tensor.cast %attn3: tensor<?x?x?xf32, #csrv> to tensor<?x?x?xf32>
    return %attn4 : tensor<?x?x?xf32, #densev>
    }
    }


    // Local Variables:
    // rmsbolt-command: "lapis-opt --sparse-compiler-kokkos='pt-backend=mpi'"
    // rmsbolt-automatic-recompile: on-save
    // End:
cat_tensor.py
    #!/usr/bin/env python3

    import numpy as np
    import sys

    if len(sys.argv) == 3 and sys.argv[2] == "-f":
    a = np.frombuffer(open(sys.argv[1], "rb").read(), dtype=np.float32)
    else:
    a = (
    np.frombuffer(open(sys.argv[1], "rb").read(), dtype=np.int64)
    .reshape(-1, 2)
    .transpose(1, 0)
    )
    print(a)
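
For reference, a hypothetical generator for the inputs these scripts read: int64 (row, col) pairs in <base>.coordinates.bin plus a text <base>.coordinates.info holding M, N and the nonzero count (layout as parsed by a.py and init_bin_matrix_sizes in wrapper.cpp):

    import numpy as np

    # Hypothetical 3x3 example; "demo" is not a real dataset name.
    coords = np.array([[0, 0], [0, 1], [1, 1], [2, 0]], dtype=np.int64)
    coords.tofile("demo.coordinates.bin")
    with open("demo.coordinates.info", "w") as f:
        f.write(f"3 3 {len(coords)}\n")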
mmio.c
    /*
    * Matrix Market I/O library for ANSI C
    *
    * See http://math.nist.gov/MatrixMarket for details.
    *
    *
    */


    #include <stdio.h>
    #include <string.h>
    #include <stdlib.h>
    #include <ctype.h>

    #include "mmio.h"

    int mm_read_unsymmetric_sparse(const char *fname, int *M_, int *N_, int *nz_,
    double **val_, int **I_, int **J_)
    {
    FILE *f;
    MM_typecode matcode;
    int M, N, nz;
    int i;
    double *val;
    int *I, *J;

    if ((f = fopen(fname, "r")) == NULL)
    return -1;


    if (mm_read_banner(f, &matcode) != 0)
    {
    printf("mm_read_unsymetric: Could not process Matrix Market banner ");
    printf(" in file [%s]\n", fname);
    return -1;
    }



    if ( !(mm_is_real(matcode) && mm_is_matrix(matcode) &&
    mm_is_sparse(matcode)))
    {
    fprintf(stderr, "Sorry, this application does not support ");
    fprintf(stderr, "Market Market type: [%s]\n",
    mm_typecode_to_str(matcode));
    return -1;
    }

    /* find out size of sparse matrix: M, N, nz .... */

    if (mm_read_mtx_crd_size(f, &M, &N, &nz) !=0)
    {
    fprintf(stderr, "read_unsymmetric_sparse(): could not parse matrix size.\n");
    return -1;
    }

    *M_ = M;
    *N_ = N;
    *nz_ = nz;

    /* reserve memory for matrices */

    I = (int *) malloc(nz * sizeof(int));
    J = (int *) malloc(nz * sizeof(int));
    val = (double *) malloc(nz * sizeof(double));

    *val_ = val;
    *I_ = I;
    *J_ = J;

    /* NOTE: when reading in doubles, ANSI C requires the use of the "l" */
    /* specifier as in "%lg", "%lf", "%le", otherwise errors will occur */
    /* (ANSI C X3.159-1989, Sec. 4.9.6.2, p. 136 lines 13-15) */

    for (i=0; i<nz; i++)
    {
    fscanf(f, "%d %d %lg\n", &I[i], &J[i], &val[i]);
    I[i]--; /* adjust from 1-based to 0-based */
    J[i]--;
    }
    fclose(f);

    return 0;
    }

    int mm_is_valid(MM_typecode matcode)
    {
    if (!mm_is_matrix(matcode)) return 0;
    if (mm_is_dense(matcode) && mm_is_pattern(matcode)) return 0;
    if (mm_is_real(matcode) && mm_is_hermitian(matcode)) return 0;
    if (mm_is_pattern(matcode) && (mm_is_hermitian(matcode) ||
    mm_is_skew(matcode))) return 0;
    return 1;
    }

    int mm_read_banner(FILE *f, MM_typecode *matcode)
    {
    char line[MM_MAX_LINE_LENGTH];
    char banner[MM_MAX_TOKEN_LENGTH];
    char mtx[MM_MAX_TOKEN_LENGTH];
    char crd[MM_MAX_TOKEN_LENGTH];
    char data_type[MM_MAX_TOKEN_LENGTH];
    char storage_scheme[MM_MAX_TOKEN_LENGTH];
    char *p;


    mm_clear_typecode(matcode);

    if (fgets(line, MM_MAX_LINE_LENGTH, f) == NULL)
    return MM_PREMATURE_EOF;

    if (sscanf(line, "%s %s %s %s %s", banner, mtx, crd, data_type,
    storage_scheme) != 5)
    return MM_PREMATURE_EOF;

    for (p=mtx; *p!='\0'; *p=tolower(*p),p++); /* convert to lower case */
    for (p=crd; *p!='\0'; *p=tolower(*p),p++);
    for (p=data_type; *p!='\0'; *p=tolower(*p),p++);
    for (p=storage_scheme; *p!='\0'; *p=tolower(*p),p++);

    /* check for banner */
    if (strncmp(banner, MatrixMarketBanner, strlen(MatrixMarketBanner)) != 0)
    return MM_NO_HEADER;

    /* first field should be "mtx" */
    if (strcmp(mtx, MM_MTX_STR) != 0)
    return MM_UNSUPPORTED_TYPE;
    mm_set_matrix(matcode);


    /* second field describes whether this is a sparse matrix (in coordinate
    storage) or a dense array */


    if (strcmp(crd, MM_SPARSE_STR) == 0)
    mm_set_sparse(matcode);
    else
    if (strcmp(crd, MM_DENSE_STR) == 0)
    mm_set_dense(matcode);
    else
    return MM_UNSUPPORTED_TYPE;


    /* third field */

    if (strcmp(data_type, MM_REAL_STR) == 0)
    mm_set_real(matcode);
    else
    if (strcmp(data_type, MM_COMPLEX_STR) == 0)
    mm_set_complex(matcode);
    else
    if (strcmp(data_type, MM_PATTERN_STR) == 0)
    mm_set_pattern(matcode);
    else
    if (strcmp(data_type, MM_INT_STR) == 0)
    mm_set_integer(matcode);
    else
    return MM_UNSUPPORTED_TYPE;


    /* fourth field */

    if (strcmp(storage_scheme, MM_GENERAL_STR) == 0)
    mm_set_general(matcode);
    else
    if (strcmp(storage_scheme, MM_SYMM_STR) == 0)
    mm_set_symmetric(matcode);
    else
    if (strcmp(storage_scheme, MM_HERM_STR) == 0)
    mm_set_hermitian(matcode);
    else
    if (strcmp(storage_scheme, MM_SKEW_STR) == 0)
    mm_set_skew(matcode);
    else
    return MM_UNSUPPORTED_TYPE;


    return 0;
    }

    int mm_write_mtx_crd_size(FILE *f, int M, int N, int nz)
    {
    /* fprintf returns a character count, not an item count */
    if (fprintf(f, "%d %d %d\n", M, N, nz) < 0)
    return MM_COULD_NOT_WRITE_FILE;
    else
    return 0;
    }

    int mm_read_mtx_crd_size(FILE *f, int *M, int *N, int *nz )
    {
    char line[MM_MAX_LINE_LENGTH];
    int num_items_read;

    /* set return null parameter values, in case we exit with errors */
    *M = *N = *nz = 0;

    /* now continue scanning until you reach the end-of-comments */
    do
    {
    if (fgets(line,MM_MAX_LINE_LENGTH,f) == NULL)
    return MM_PREMATURE_EOF;
    }while (line[0] == '%');

    /* line[] is either blank or has M,N, nz */
    if (sscanf(line, "%d %d %d", M, N, nz) == 3)
    return 0;

    else
    do
    {
    num_items_read = fscanf(f, "%d %d %d", M, N, nz);
    if (num_items_read == EOF) return MM_PREMATURE_EOF;
    }
    while (num_items_read != 3);

    return 0;
    }


    int mm_read_mtx_array_size(FILE *f, int *M, int *N)
    {
    char line[MM_MAX_LINE_LENGTH];
    int num_items_read;
    /* set return null parameter values, in case we exit with errors */
    *M = *N = 0;

    /* now continue scanning until you reach the end-of-comments */
    do
    {
    if (fgets(line,MM_MAX_LINE_LENGTH,f) == NULL)
    return MM_PREMATURE_EOF;
    }while (line[0] == '%');

    /* line[] is either blank or has M,N, nz */
    if (sscanf(line, "%d %d", M, N) == 2)
    return 0;

    else /* we have a blank line */
    do
    {
    num_items_read = fscanf(f, "%d %d", M, N);
    if (num_items_read == EOF) return MM_PREMATURE_EOF;
    }
    while (num_items_read != 2);

    return 0;
    }

    int mm_write_mtx_array_size(FILE *f, int M, int N)
    {
    if (fprintf(f, "%d %d\n", M, N) < 0)
    return MM_COULD_NOT_WRITE_FILE;
    else
    return 0;
    }



    /*-------------------------------------------------------------------------*/

    /******************************************************************/
    /* use when I[], J[], and val[] are already allocated            */
    /******************************************************************/

    int mm_read_mtx_crd_data(FILE *f, int M, int N, int nz, int I[], int J[],
    double val[], MM_typecode matcode)
    {
    int i;
    if (mm_is_complex(matcode))
    {
    for (i=0; i<nz; i++)
    if (fscanf(f, "%d %d %lg %lg", &I[i], &J[i], &val[2*i], &val[2*i+1])
    != 4) return MM_PREMATURE_EOF;
    }
    else if (mm_is_real(matcode))
    {
    for (i=0; i<nz; i++)
    {
    if (fscanf(f, "%d %d %lg\n", &I[i], &J[i], &val[i])
    != 3) return MM_PREMATURE_EOF;

    }
    }

    else if (mm_is_pattern(matcode))
    {
    for (i=0; i<nz; i++)
    if (fscanf(f, "%d %d", &I[i], &J[i])
    != 2) return MM_PREMATURE_EOF;
    }
    else
    return MM_UNSUPPORTED_TYPE;

    return 0;

    }

    int mm_read_mtx_crd_entry(FILE *f, int *I, int *J,
    double *real, double *imag, MM_typecode matcode)
    {
    if (mm_is_complex(matcode))
    {
    if (fscanf(f, "%d %d %lg %lg", I, J, real, imag)
    != 4) return MM_PREMATURE_EOF;
    }
    else if (mm_is_real(matcode))
    {
    if (fscanf(f, "%d %d %lg\n", I, J, real)
    != 3) return MM_PREMATURE_EOF;

    }

    else if (mm_is_pattern(matcode))
    {
    if (fscanf(f, "%d %d", I, J) != 2) return MM_PREMATURE_EOF;
    }
    else
    return MM_UNSUPPORTED_TYPE;

    return 0;

    }


    /************************************************************************
    mm_read_mtx_crd() fills M, N, nz, the value array, and returns the
    typecode, e.g. 'MCRS'
    if matrix is complex, values[] is of size 2*nz,
    (nz pairs of real/imaginary values)
    ************************************************************************/

    int mm_read_mtx_crd(char *fname, int *M, int *N, int *nz, int **I, int **J,
    double **val, MM_typecode *matcode)
    {
    int ret_code;
    FILE *f;

    if (strcmp(fname, "stdin") == 0) f=stdin;
    else
    if ((f = fopen(fname, "r")) == NULL)
    return MM_COULD_NOT_READ_FILE;


    if ((ret_code = mm_read_banner(f, matcode)) != 0)
    return ret_code;

    if (!(mm_is_valid(*matcode) && mm_is_sparse(*matcode) &&
    mm_is_matrix(*matcode)))
    return MM_UNSUPPORTED_TYPE;

    if ((ret_code = mm_read_mtx_crd_size(f, M, N, nz)) != 0)
    return ret_code;


    *I = (int *) malloc(*nz * sizeof(int));
    *J = (int *) malloc(*nz * sizeof(int));
    *val = NULL;

    if (mm_is_complex(*matcode))
    {
    *val = (double *) malloc(*nz * 2 * sizeof(double));
    ret_code = mm_read_mtx_crd_data(f, *M, *N, *nz, *I, *J, *val,
    *matcode);
    if (ret_code != 0) return ret_code;
    }
    else if (mm_is_real(*matcode))
    {
    *val = (double *) malloc(*nz * sizeof(double));
    ret_code = mm_read_mtx_crd_data(f, *M, *N, *nz, *I, *J, *val,
    *matcode);
    if (ret_code != 0) return ret_code;
    }

    else if (mm_is_pattern(*matcode))
    {
    ret_code = mm_read_mtx_crd_data(f, *M, *N, *nz, *I, *J, *val,
    *matcode);
    if (ret_code != 0) return ret_code;
    }

    if (f != stdin) fclose(f);
    return 0;
    }

    int mm_write_banner(FILE *f, MM_typecode matcode)
    {
    char *str = mm_typecode_to_str(matcode);
    int ret_code;

    ret_code = fprintf(f, "%s %s\n", MatrixMarketBanner, str);
    free(str);
    if (ret_code < 0)
    return MM_COULD_NOT_WRITE_FILE;
    else
    return 0;
    }

    int mm_write_mtx_crd(char fname[], int M, int N, int nz, int I[], int J[],
    double val[], MM_typecode matcode)
    {
    FILE *f;
    int i;

    if (strcmp(fname, "stdout") == 0)
    f = stdout;
    else
    if ((f = fopen(fname, "w")) == NULL)
    return MM_COULD_NOT_WRITE_FILE;

    /* print banner followed by typecode */
    fprintf(f, "%s ", MatrixMarketBanner);
    fprintf(f, "%s\n", mm_typecode_to_str(matcode));

    /* print matrix sizes and nonzeros */
    fprintf(f, "%d %d %d\n", M, N, nz);

    /* print values */
    if (mm_is_pattern(matcode))
    for (i=0; i<nz; i++)
    fprintf(f, "%d %d\n", I[i], J[i]);
    else
    if (mm_is_real(matcode))
    for (i=0; i<nz; i++)
    fprintf(f, "%d %d %20.16g\n", I[i], J[i], val[i]);
    else
    if (mm_is_complex(matcode))
    for (i=0; i<nz; i++)
    fprintf(f, "%d %d %20.16g %20.16g\n", I[i], J[i], val[2*i],
    val[2*i+1]);
    else
    {
    if (f != stdout) fclose(f);
    return MM_UNSUPPORTED_TYPE;
    }

    if (f !=stdout) fclose(f);

    return 0;
    }


    /**
    * Create a new copy of a string s. mm_strdup() is a common routine, but
    * not part of ANSI C, so it is included here. Used by mm_typecode_to_str().
    *
    */
    char *mm_strdup(const char *s)
    {
    int len = strlen(s);
    char *s2 = (char *) malloc((len+1)*sizeof(char));
    return strcpy(s2, s);
    }

    char *mm_typecode_to_str(MM_typecode matcode)
    {
    char buffer[MM_MAX_LINE_LENGTH];
    char *types[4];
    char *mm_strdup(const char *);
    int error =0;

    /* check for MTX type */
    if (mm_is_matrix(matcode))
    types[0] = MM_MTX_STR;
    else
    error=1;

    /* check for CRD or ARR matrix */
    if (mm_is_sparse(matcode))
    types[1] = MM_SPARSE_STR;
    else
    if (mm_is_dense(matcode))
    types[1] = MM_DENSE_STR;
    else
    return NULL;

    /* check for element data type */
    if (mm_is_real(matcode))
    types[2] = MM_REAL_STR;
    else
    if (mm_is_complex(matcode))
    types[2] = MM_COMPLEX_STR;
    else
    if (mm_is_pattern(matcode))
    types[2] = MM_PATTERN_STR;
    else
    if (mm_is_integer(matcode))
    types[2] = MM_INT_STR;
    else
    return NULL;


    /* check for symmetry type */
    if (mm_is_general(matcode))
    types[3] = MM_GENERAL_STR;
    else
    if (mm_is_symmetric(matcode))
    types[3] = MM_SYMM_STR;
    else
    if (mm_is_hermitian(matcode))
    types[3] = MM_HERM_STR;
    else
    if (mm_is_skew(matcode))
    types[3] = MM_SKEW_STR;
    else
    return NULL;

    sprintf(buffer,"%s %s %s %s", types[0], types[1], types[2], types[3]);
    return mm_strdup(buffer);

    }
mmio.h
    /*
    * Matrix Market I/O library for ANSI C
    *
    * See http://math.nist.gov/MatrixMarket for details.
    *
    *
    */

    #ifndef MM_IO_H
    #define MM_IO_H

    #include <stdio.h> /* for FILE, used throughout this header */

    #define MM_MAX_LINE_LENGTH 1025
    #define MatrixMarketBanner "%%MatrixMarket"
    #define MM_MAX_TOKEN_LENGTH 64

    typedef char MM_typecode[4];

    char *mm_typecode_to_str(MM_typecode matcode);

    int mm_read_banner(FILE *f, MM_typecode *matcode);
    int mm_read_mtx_crd_size(FILE *f, int *M, int *N, int *nz);
    int mm_read_mtx_array_size(FILE *f, int *M, int *N);

    int mm_write_banner(FILE *f, MM_typecode matcode);
    int mm_write_mtx_crd_size(FILE *f, int M, int N, int nz);
    int mm_write_mtx_array_size(FILE *f, int M, int N);


    /********************* MM_typecode query functions ***************************/

    #define mm_is_matrix(typecode) ((typecode)[0]=='M')

    #define mm_is_sparse(typecode) ((typecode)[1]=='C')
    #define mm_is_coordinate(typecode)((typecode)[1]=='C')
    #define mm_is_dense(typecode) ((typecode)[1]=='A')
    #define mm_is_array(typecode) ((typecode)[1]=='A')

    #define mm_is_complex(typecode) ((typecode)[2]=='C')
    #define mm_is_real(typecode) ((typecode)[2]=='R')
    #define mm_is_pattern(typecode) ((typecode)[2]=='P')
    #define mm_is_integer(typecode) ((typecode)[2]=='I')

    #define mm_is_symmetric(typecode)((typecode)[3]=='S')
    #define mm_is_general(typecode) ((typecode)[3]=='G')
    #define mm_is_skew(typecode) ((typecode)[3]=='K')
    #define mm_is_hermitian(typecode)((typecode)[3]=='H')

    int mm_is_valid(MM_typecode matcode); /* too complex for a macro */


    /********************* MM_typecode modify functions **************************/

    #define mm_set_matrix(typecode) ((*typecode)[0]='M')
    #define mm_set_coordinate(typecode) ((*typecode)[1]='C')
    #define mm_set_array(typecode) ((*typecode)[1]='A')
    #define mm_set_dense(typecode) mm_set_array(typecode)
    #define mm_set_sparse(typecode) mm_set_coordinate(typecode)

    #define mm_set_complex(typecode)((*typecode)[2]='C')
    #define mm_set_real(typecode) ((*typecode)[2]='R')
    #define mm_set_pattern(typecode)((*typecode)[2]='P')
    #define mm_set_integer(typecode)((*typecode)[2]='I')


    #define mm_set_symmetric(typecode)((*typecode)[3]='S')
    #define mm_set_general(typecode)((*typecode)[3]='G')
    #define mm_set_skew(typecode) ((*typecode)[3]='K')
    #define mm_set_hermitian(typecode)((*typecode)[3]='H')

    #define mm_clear_typecode(typecode) ((*typecode)[0]=(*typecode)[1]= \
    (*typecode)[2]=' ',(*typecode)[3]='G')

    #define mm_initialize_typecode(typecode) mm_clear_typecode(typecode)


    /********************* Matrix Market error codes ***************************/


    #define MM_COULD_NOT_READ_FILE 11
    #define MM_PREMATURE_EOF 12
    #define MM_NOT_MTX 13
    #define MM_NO_HEADER 14
    #define MM_UNSUPPORTED_TYPE 15
    #define MM_LINE_TOO_LONG 16
    #define MM_COULD_NOT_WRITE_FILE 17


    /******************** Matrix Market internal definitions ********************

       MM_matrix_typecode: 4-character sequence

                           object     sparse/    data       storage
                                      dense      type       scheme

       string position:    [0]        [1]        [2]        [3]

       Matrix typecode:    M(atrix)   C(oord)    R(eal)     G(eneral)
                                      A(array)   C(omplex)  H(ermitian)
                                                 P(attern)  S(ymmetric)
                                                 I(nteger)  K(skew)

     ***********************************************************************/

    #define MM_MTX_STR "matrix"
    #define MM_ARRAY_STR "array"
    #define MM_DENSE_STR "array"
    #define MM_COORDINATE_STR "coordinate"
    #define MM_SPARSE_STR "coordinate"
    #define MM_COMPLEX_STR "complex"
    #define MM_REAL_STR "real"
    #define MM_INT_STR "integer"
    #define MM_GENERAL_STR "general"
    #define MM_SYMM_STR "symmetric"
    #define MM_HERM_STR "hermitian"
    #define MM_SKEW_STR "skew-symmetric"
    #define MM_PATTERN_STR "pattern"
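
Putting the table and strings together, a Matrix Market banner line reads "%%MatrixMarket matrix coordinate real general"; a minimal Python check mirroring mm_read_banner (case handling simplified):

    # Sketch of mm_read_banner's field checks, in Python.
    line = "%%MatrixMarket matrix coordinate real general"
    banner, mtx, crd, data_type, storage = line.split()
    assert banner == "%%MatrixMarket" and mtx.lower() == "matrix"
    assert crd.lower() in ("coordinate", "array")
    assert data_type.lower() in ("real", "complex", "pattern", "integer")
    assert storage.lower() in ("general", "symmetric", "hermitian", "skew-symmetric")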


    /* high level routines */

    int mm_write_mtx_crd(char fname[], int M, int N, int nz, int I[], int J[],
    double val[], MM_typecode matcode);
    int mm_read_mtx_crd_data(FILE *f, int M, int N, int nz, int I[], int J[],
    double val[], MM_typecode matcode);
    int mm_read_mtx_crd_entry(FILE *f, int *I, int *J, double *real, double *img,
    MM_typecode matcode);

    int mm_read_unsymmetric_sparse(const char *fname, int *M_, int *N_, int *nz_,
    double **val_, int **I_, int **J_);



    #endif
wrapper.cpp
    #include <algorithm>
    #include <cassert>
    #include <cfloat>
    #include <chrono>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <fstream>
    #include <ios>
    #include <iostream>
    #include <memory>
    #include <numeric>
    #include <random>
    #include <span>
    #include <string>
    #include <vector>

    extern "C" {
    #include "mmio.h"
    }

    #include <argparse/argparse.hpp>
    #include <fmt/format.h>
    #include <fmt/ranges.h>

    #include "mlir/ExecutionEngine/SparseTensor/COO.h"
    #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/Sequence.h"
    #include "llvm/Support/FormatVariadic.h"

    #if DISTRIBUTED
    #include "parttensor_mpi_backend/Storage.h"
    #include "parttensor_mpi_backend/parttensor_mpi_backend.h"
    #include <mpi.h>

    template <class... T>
    using PartTensorStorage = mlir::parttensor_mpi::MPITensorStorage<T...>;
    using mlir::parttensor_mpi::Product;
    using mlir::parttensor_mpi::SplitLoHiPoint;
    using mlir::parttensor_mpi::SubtractPoints;
    #endif // DISTRIBUTED

    using namespace mlir::sparse_tensor;
    using std::begin;
    using std::end;
    using std::make_unique;
    using std::span;
    using std::unique_ptr;
    using std::vector;

    using index_t = uint64_t;

    void getRss() {
    std::ifstream file("/proc/self/status");
    std::string line;
    #if DISTRIBUTED
    const auto rank = _mlir_ciface_mpi_getRank();
    fmt::print("{}: ", rank);
    #endif
    while (std::getline(file, line)) {
    if (line.starts_with("VmRSS:")) {
    line = line.substr(7);
    std::istringstream iss(line);
    long rss{};
    iss >> rss;
    fmt::print("{} GB\n", double(rss) / pow(2, 20));
    std::string rest;
    iss >> rest;
    assert(rest == "kB");
    break;
    }
    }
    }

    bool nearly_equal(float a, float b, float epsilon = 128 * FLT_EPSILON,
    float abs_th = FLT_MIN)
    // those defaults are arbitrary and could be removed
    {
    assert(std::numeric_limits<float>::epsilon() <= epsilon);
    assert(epsilon < 1.f);

    if (a == b)
    return true;

    auto diff = std::abs(a - b);
    auto norm =
    std::min((std::abs(a) + std::abs(b)), std::numeric_limits<float>::max());
    // or even faster: std::min(std::abs(a + b),
    // std::numeric_limits<float>::max()); keeping this commented out until I
    // update figures below
    return diff < std::max(abs_th, epsilon * norm);
    }
    // copied from: https://en.cppreference.com/w/cpp/types/numeric_limits/epsilon
    template <class T>
    std::enable_if_t<not std::numeric_limits<T>::is_integer, bool>
    equal_within_ulps(T x, T y, std::size_t n) {
    // Since `epsilon()` is the gap size (ULP, unit in the last place)
    // of floating-point numbers in interval [1, 2), we can scale it to
    // the gap size in interval [2^e, 2^{e+1}), where `e` is the exponent
    // of `x` and `y`.

    // If `x` and `y` have different gap sizes (which means they have
    // different exponents), we take the smaller one. Taking the bigger
    // one is also reasonable, I guess.
    const T m = std::min(std::fabs(x), std::fabs(y));

    // Subnormal numbers have fixed exponent, which is `min_exponent - 1`.
    const int exp = m < std::numeric_limits<T>::min()
    ? std::numeric_limits<T>::min_exponent - 1
    : std::ilogb(m);

    // We consider `x` and `y` equal if the difference between them is
    // within `n` ULPs.
    return std::abs(x - y) <=
    n * std::ldexp(std::numeric_limits<T>::epsilon(), exp);
    }
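
Python's math.ulp (3.9+) expresses the same idea directly; a sketch equivalent to equal_within_ulps above:

    import math

    def equal_within_ulps(x, y, n):
        # n ULPs at the smaller of the two magnitudes, as in the C++ template.
        return abs(x - y) <= n * math.ulp(min(abs(x), abs(y)))

    assert equal_within_ulps(1.0, 1.0 + 2 * math.ulp(1.0), 2)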

    template <class T, std::size_t dims = 1> struct memref_t {
    uint64_t deadbeef;
    T *dataptr;
    uint64_t offset;
    uint64_t sizes[dims];
    uint64_t strides[dims];
    };

    using memref_1d_i64 = memref_t<uint64_t, 1>;
    using memref_2d_f32 = memref_t<float, 2>;
    using memref_3d_f32 = memref_t<float, 3>;

    template <class T> T multiply(span<T> sizes) {
    return std::accumulate(begin(sizes), end(sizes), size_t(1),
    std::multiplies<T>());
    }
    template <class T, size_t dims> void print_memref(memref_t<T, dims> out) {
    fmt::print("sizes: {}\n", out.sizes);
    fmt::print("strides: {}\n", out.strides);
    fmt::print("values: {}\n",
    span(out.dataptr, multiply(span(out.sizes, dims))));
    }
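
memref_t mirrors MLIR's ranked memref descriptor (allocated pointer, aligned data pointer, offset, sizes[rank], strides[rank]); the deadbeef field stands in for the allocated-pointer slot, as the VEC_TO_MEMREF* macros below show. A ctypes sketch of the 2-D float variant, assuming that layout:

    import ctypes

    # ctypes mirror of memref_t<float, 2> (layout assumed from the struct above).
    class Memref2DF32(ctypes.Structure):
        _fields_ = [
            ("allocated", ctypes.c_void_p),              # the "deadbeef" slot
            ("aligned", ctypes.POINTER(ctypes.c_float)), # dataptr
            ("offset", ctypes.c_uint64),
            ("sizes", ctypes.c_uint64 * 2),
            ("strides", ctypes.c_uint64 * 2),
        ]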

    extern "C" {
    extern void *part_tensor_softmax(void *A, void *out_file_name);
    extern void printSeperator();
    extern uint64_t _mlir_ciface_mpi_getRank();
    extern void delSparseTensor(void *tensor);
    }

    void printSeperator() {
    printf("---------------------------------------------------------------------"
    "-----------\n");
    }

    auto init_bin_matrix_sizes(const std::string &infile, unsigned &M, unsigned &N,
    unsigned &nnz) {
    auto base_name = infile.substr(0, infile.find_last_of('.'));
    auto info_file = base_name + ".info";
    auto file = fopen(info_file.c_str(), "r");
    if (!file) {
    std::cout << "Not valid file\n";
    exit(1);
    }
    int M_, N_, nnz_;
    auto nfieldsRead = fscanf(file, "%d %d %d", &M_, &N_, &nnz_);
    M = M_;
    N = N_;
    nnz = nnz_;
    // fmt::print("M: {} N: {} nnz: {}\n", M, N, nnz);
    if (nfieldsRead != 3) {
    std::cout << "Could not read mm size\n";
    exit(1);
    }
    fclose(file);
    }

    // std::tuple<int, int, std::unique_ptr<SparseTensorCOO<float>>>
    // init_matrix_a(std::string infile) {
    template <class FilterTy>
    auto init_bin_matrix_a(const std::string &infile, unsigned &M, unsigned &N,
    unsigned &nnz, unsigned Nf, FilterTy &&f,
    bool overrideValues = false) {
    std::mt19937 gen;
    gen.seed(0);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    init_bin_matrix_sizes(infile, M, N, nnz);
    auto file = fopen(infile.c_str(), "rb");
    std::ifstream edgeData(
    (infile.substr(0, infile.find_first_of('.')) + ".edge.data.bin").c_str(),
    std::ios::binary);
    auto stCooA = std::make_unique<SparseTensorCOO<float>>(
    std::vector<index_t>({index_t(M), index_t(N), index_t(Nf)}));
    std::vector<index_t> coords(3);
    auto statusUpdate = 0;
    std::vector<float> vals(Nf);
    for (auto i : llvm::seq(0u, nnz)) {
    if (statusUpdate++ == 10000000) {
    fmt::print(".");
    statusUpdate = 0;
    }
    // The coordinates are 2d but tensor is 3d
    auto numRead = fread(coords.data(), sizeof(decltype(coords)::value_type),
    std::size(coords) - 1, file);
    assert(numRead == std::size(coords) - 1 && "Read error"); // only 2 of 3 coords come from the file
    const auto numBytesToRead = sizeof(decltype(vals)::value_type) * Nf;
    if (f(coords[0])) {
    edgeData.read(reinterpret_cast<char *>(vals.data()), numBytesToRead);
    for (auto j : llvm::seq(0u, Nf)) {
    coords.back() = j;
    stCooA->add(coords, vals[j]);
    }
    } else {
    edgeData.seekg(numBytesToRead, std::ios_base::cur);
    }
    }
    fmt::print("\n");
    assert(!(ferror(file) || feof(file)) && "Input file too short");
    auto numRead = fread(coords.data(), sizeof(decltype(coords)::value_type),
    std::size(coords), file);
    assert(numRead == 0 && "Input file too long");
    fclose(file);
    return stCooA;
    }
    auto init_matrix_a(const std::string &infile, unsigned &M, unsigned &N,
    unsigned &nnz, bool overrideValues = false) {
    std::mt19937 gen;
    gen.seed(0);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    MM_typecode mc = {};
    auto file = fopen(infile.c_str(), "r");
    auto retCode = mm_read_banner(file, (MM_typecode *)&mc);
    if (retCode != 0) {
    std::cout << "Not valid file\n";
    exit(1);
    }
    int M_, N_, nnz_;
    retCode = mm_read_mtx_crd_size(file, &M_, &N_, &nnz_);
    M = M_;
    N = N_;
    nnz = nnz_;
    if (retCode != 0) {
    std::cout << "Could not read mm size\n";
    exit(1);
    }
    auto stCooA = std::make_unique<SparseTensorCOO<float>>(
    std::vector<index_t>({index_t(M), index_t(N)}));
    const bool IsSymmetric = mm_is_symmetric(mc);
    const bool IsPattern = mm_is_pattern(mc);
    if (IsPattern)
    overrideValues = true;
    for (auto i : llvm::seq(0, nnz_)) {
    index_t r, c;
    float v;
    if (IsPattern)
    fscanf(file, "%ld %ld", &r, &c);
    else
    fscanf(file, "%ld %ld %f", &r, &c, &v);
    v = overrideValues ? dis(gen) : v;
    r--;
    c--; // mtx is 1 based
    stCooA->add({r, c}, v);
    if (IsSymmetric && r != c) {
    stCooA->add({c, r}, v);
    }
    }
    return stCooA;
    }
    template <class T>
    void init_feats(SparseTensorStorage<index_t, index_t, T> *tensor,
    const std::string &infile, unsigned N, unsigned Dh, unsigned Nh,
    unsigned beginOffset = 0, bool ValidateFileSizes = false,
    bool overrideValues = false) {
    std::mt19937 gen;
    gen.seed(0);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    std::ifstream file(
    (infile.substr(0, infile.find_first_of('.')) + ".vert.data.bin").c_str(),
    std::ios::binary);
    if (!file) {
    std::cout << "Not valid file\n";
    exit(1);
    }
    std::streampos fileBegin =
    file.tellg(); // Remember the starting position
    file.seekg(0, std::ios::end); // Seek to the end of the file
    size_t fileSize = file.tellg() - fileBegin;
    file.seekg(beginOffset, std::ios::beg); // Seek back to the base offset

    assert((!ValidateFileSizes ||
    fileSize == N * Dh * Nh * sizeof(T)) && "File size mismatch");
    std::vector<T> *vals;
    tensor->getValues(&vals);
    file.read(reinterpret_cast<char *>(vals->data()), sizeof(T) * vals->size());
    // ValidateFileSizes -> file.good()
    assert((!ValidateFileSizes || file.good()) && "File read error");
    return;
    }
    template <class T>
    auto init_feats(const std::string &infile, unsigned &N, unsigned Dh,
    unsigned Nh, unsigned beginOffset = 0,
    bool overrideValues = false) {
    std::mt19937 gen;
    gen.seed(0);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    std::ifstream file(infile, std::ios::binary);
    if (!file) {
    std::cout << "Not valid file\n";
    exit(1);
    }
    std::streampos fileBegin =
    file.tellg(); // Remember the starting position
    file.seekg(0, std::ios::end); // Seek to the end of the file
    size_t fileSize = file.tellg() - fileBegin;
    file.seekg(beginOffset, std::ios::beg); // Seek back to the base offset

    assert(fileSize == N * Dh * Nh * sizeof(T) && "File size mismatch");
    auto stCooA = std::make_unique<SparseTensorCOO<T>>(
    std::vector<index_t>({index_t(N), index_t(Dh), index_t(Nh)}));
    for (auto n : llvm::seq(0u, N))
    for (auto dh : llvm::seq(0u, Dh))
    for (auto nh : llvm::seq(0u, Nh)) {
    T val;
    file.read(reinterpret_cast<char *>(&val), sizeof(T));
    val = overrideValues ? dis(gen) : val;
    stCooA->add({n, dh, nh}, val);
    assert(file.good() && "File read error");
    }
    return stCooA;
    }

    std::unique_ptr<SparseTensorCOO<float>> init_matrix_a(index_t rowSize = 4) {
    assert(0 && "Need to init with multiple features!");
    auto dims = std::vector<size_t>{rowSize, rowSize};
    auto stCooA = std::make_unique<SparseTensorCOO<float>>(dims);
    for (auto i : llvm::seq(0ul, rowSize))
    for (auto j : llvm::seq(0ul, rowSize))
    stCooA->add({i, j}, float(i * rowSize + j));
    return stCooA;
    }

    template <class T>
    auto init_random(SparseTensorCOO<T> &vec, size_t N, size_t seed = 0) {
    std::mt19937 gen;
    gen.seed(seed);
    std::uniform_real_distribution<> dis(T(0), T(100));
    auto vals = vec.getElements();
    for (auto i : llvm::seq(0ul, N))
    (vals.at(i)).value = (dis(gen));
    }

    template <class T>
    auto init_random(std::vector<T> &vec, size_t N, size_t seed = 0) {
    std::mt19937 gen;
    gen.seed(seed);
    std::uniform_real_distribution<> dis(T(0), T(100));
    for (auto i : llvm::seq(0ul, N))
    vec.push_back(dis(gen));
    }

    template <> auto init_random(std::vector<uint8_t> &vec, size_t N, size_t seed) {
    std::mt19937 gen;
    gen.seed(seed);
    std::bernoulli_distribution dis(.5);
    for (auto i : llvm::seq(0ul, N))
    vec.push_back(dis(gen));
    }

    /// Initialize a 2D sparse square tensor with random values
    /// @param vec: SparseTensorCOO<float> to be initialized
    /// @param N: rows and columns of the 2D tensor
    /// @param seed: seed for random number generator
    /// @param density: density of the tensor
    template <class T>
    auto init_random_sparse_2d(SparseTensorCOO<float> &stCoo, size_t N, size_t seed,
    double density = 0.5) {
    std::mt19937 gen;
    gen.seed(seed);
    std::bernoulli_distribution dis(density);
    for (auto i : llvm::seq(0ul, N * N))
    if (dis(gen))
    stCoo.add({i / N, i % N}, T{1});
    }

    template <class T> void write_tensor(span<T> vec, std::string filename) {
    std::ofstream(filename, std::ios::binary)
    .write(reinterpret_cast<char *>(vec.data()), sizeof(T) * vec.size());
    }

    template <class T>
    void write_vector_to_stcoo3d(const std::vector<T> &vec,
    SparseTensorCOO<T> &stCoo) {
    auto dims = stCoo.getDimSizes();
    auto [N, Dh, Nh] = std::tuple(dims[0], dims[1], dims[2]);
    llvm::for_each(llvm::enumerate(vec), [&stCoo, N, Dh, Nh](auto pair) {
    auto [index, val] = pair;
    auto nh = index % Nh;
    auto rest = (index - nh) / Nh;
    auto dh = rest % Dh;
    auto n = (rest - dh) / Dh;
    stCoo.add({n, dh, nh}, val);
    });
    }

    #if DISTRIBUTED
    auto get2DPartition(size_t rows, size_t cols, size_t nh, size_t rowParts,
    size_t colParts) {
    std::vector<index_t> partitionPlan;
    auto rowPartitionSize = rows / rowParts;
    auto colPartitionSize = cols / colParts;
    assert(rows % rowParts == 0 && "rows % rowParts != 0");
    assert(cols % colParts == 0 && "cols % colParts != 0");
    for (int j = 0; j < cols; j += colPartitionSize) {
    for (int i = 0; i < rows; i += rowPartitionSize) {
    if (false && _mlir_ciface_mpi_getRank() == 0)
    std::cout << "(" << i << "," << j << ") -> (" << i + rowPartitionSize
    << "," << j + colPartitionSize << ") \n";
    partitionPlan.push_back(index_t(i));
    partitionPlan.push_back(index_t(j));
    if (nh)
    partitionPlan.push_back(index_t(0));
    partitionPlan.push_back(index_t(i + rowPartitionSize));
    partitionPlan.push_back(index_t(j + colPartitionSize));
    if (nh)
    partitionPlan.push_back(index_t(nh));
    }
    }
    return partitionPlan;
    }
    #endif
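
get2DPartition flattens one (lo, hi) corner pair per part into a single index list: 6 entries per part when nh is nonzero, otherwise 4. A Python sketch of the nh == 0 case used for the matrix partition:

    # Sketch of get2DPartition for nh == 0.
    def get_2d_partition(rows, cols, row_parts, col_parts):
        rp, cp = rows // row_parts, cols // col_parts
        plan = []
        for j in range(0, cols, cp):
            for i in range(0, rows, rp):
                plan += [i, j, i + rp, j + cp]
        return plan

    print(get_2d_partition(8, 8, 2, 1))  # [0, 0, 4, 8, 4, 0, 8, 8]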

    #define VEC_TO_MEMREF2D(v, size_0, size_1) \
    (void *)0xdeadbeef, v.data(), 0, size_0, size_1, size_1, 1

    #define VEC_TO_MEMREF3D(v, size_0, size_1, size_2) \
    (void *)0xdeadbeef, v.data(), 0, size_0, size_1, size_2, size_2 *size_1, \
    size_2, 1

    #define TENSOR_2D_ARG(t) \
    void *t##_deadbeef, void *t##_dataptr, uint64_t t##_offset, \
    uint64_t t##_sizes_0, uint64_t t##_sizes_1, uint64_t t##_strides_0, \
    uint64_t t##_strides_1

    #define TENSOR_3D_ARG(t) \
    void *t##_deadbeef, void *t##_dataptr, uint64_t t##_offset, \
    uint64_t t##_sizes_0, uint64_t t##_sizes_1, uint64_t t##_sizes_2, \
    uint64_t t##_strides_0, uint64_t t##_strides_1, uint64_t t##_strides_2

    extern "C" void *sparse_mha(void *A, void *Q, void *K, void *V, void *);
    extern "C" void *pte_local_bsddmm(void *A, void *Q, void *K);
    extern "C" void *pte_local_bspmm(void *A, void *V);
    extern "C" void *pte_bsddmm(void *A, void *Q, void *K, index_t n1, index_t n2,
    index_t dh, index_t nh);
    extern "C" void *pte_local_sparse_mha(void *A, void *Q, void *K, void *V);
    extern "C" void *pte_sparse_mha(void *A, void *Q, void *K, void *V, index_t n1,
    index_t n2, index_t dh, index_t nh);

    extern "C" void lapis_initialize();
    extern "C" void lapis_finalize();

    void validate_cli(const argparse::ArgumentParser &p) {
    if (p.is_used("-n") && p.is_used("-i")) {
    fmt::print("Only one of -n OR -i can be specified\n");
    exit(1);
    }
    if (p.is_used("-d") && p.is_used("-i")) {
    fmt::print("Only one of -d OR -i can be specified\n");
    exit(1);
    }
    if (!p.is_used("--check") && !p.is_used("--ntimes")) {
    fmt::print("Atleast --check or --ntimes should be specified\n");
    exit(1);
    }
    }

    int main(int argc, char **argv) {
    using namespace mlir::sparse_tensor;
    #if DISTRIBUTED
    MPI_Init(&argc, &argv);
    #endif
    #if defined(USE_KOKKOS)
    lapis_initialize();
    #endif

    unsigned N = 4, Nh = 2, Dh = 2, Nnz{}, nTimes = 1;
    unsigned Nparts = 1;
    unsigned logRank = 99;
    // bool PerfOnly = false, LocalOnly = false, CheckCorrectness = false;
    bool LocalOnly = false, DistOnly = false, CheckCorrectness = false,
    ValidateFileSizes = false;
    double density = 0.5;
    std::string infile;
    argparse::ArgumentParser program(argv[0]);
    program.add_argument("-dh").store_into(Dh).help("head size, default = 2");
    program.add_argument("-nh").store_into(Nh).help("#head, default = 2");
    program.add_argument("-n").store_into(N).help("#nodes, default = 4");
    program.add_argument("-i").store_into(infile).help("Input File");
    program.add_argument("--ntimes").store_into(nTimes).help("rerun n times = 1");
    program.add_argument("-d", "--density")
    .store_into(density)
    .help("sparsity density, default = 0.5");
    program.add_argument("-np", "--nparts")
    .store_into(Nparts)
    .help("number of parts, default = 1");

    program.add_argument("--logrank")
    .store_into(logRank)
    .help("log when rank == logrank, default = 99");
    program.add_argument("--check")
    .store_into(CheckCorrectness)
    .help("Do correctness check, default = false")
    .flag();
    program.add_argument("--validate-file-sizes")
    .store_into(ValidateFileSizes)
    .help("Make sure Nh and Dh are compatible with the input file sizes, "
    "default = false")
    .flag();
    program.add_argument("--dist-only")
    .store_into(DistOnly)
    .help("Skip local run, default = false")
    .flag();
    program.add_argument("--local-only")
    .store_into(LocalOnly)
    .help("Skip distributed run, default = false")
    .flag();
    try {
    program.parse_args(argc, argv);
    } catch (const std::exception &err) {
    std::cerr << err.what() << std::endl;
    std::cerr << program;
    return 1;
    }

    validate_cli(program);

    decltype(init_matrix_a(4)) stCooA{}, stCooQ{};

    #if DISTRIBUTED
    const auto rank = _mlir_ciface_mpi_getRank();
    #endif // DISTRIBUTED
    if (!infile.empty()) {
    decltype(N) M;
    if (DistOnly) {
    #if DISTRIBUTED
    if (!infile.ends_with(".bin")) {
    fmt::print("Only binary input supported for distributed run\n");
    exit(1);
    }
    init_bin_matrix_sizes(infile, M, N, Nnz);
    auto aPartitionPlan = get2DPartition(N, N, 0, Nparts, 1);
    const std::vector<size_t> dims = {N, N};
    auto myaPartSpec =
    std::span(aPartitionPlan)
    .subspan(rank * std::size(dims) * 2, std::size(dims) * 2);
    auto [aaLo, aaHi] = SplitLoHiPoint(
    llvm::ArrayRef(myaPartSpec.data(), std::size(myaPartSpec)));
    // fmt::println("Rank: {} aaHi: {} aaLo: {}", rank, aaHi, aaLo);
    stCooA = init_bin_matrix_a(infile, M, N, Nnz, Dh * Nh, [=](auto n1) {
    return (aaLo[0] <= n1 && n1 < aaHi[0]);
    });
    #endif // DISTRIBUTED
    } else {
    stCooA = infile.ends_with(".bin")
    ? init_bin_matrix_a(infile, M, N, Nnz, Nh,
    [](auto) { return true; })
    : init_matrix_a(infile, M, N, Nnz);
    }
    assert(M == N);
    } else {
    assert(0);
    }
    fmt::println(">>> Loaded SparseMat");
    fflush(nullptr);

    system("date");

    assert(N % Nparts == 0 && "Problem size (n) should be divisible by nparts");
    const size_t TensorSize = N * Dh * Nh, MatSize = N * N;
    const std::vector<size_t> dims = {N, N, Nh};
    const std::vector<size_t> featDims = {N, Dh, Nh};
    const bool AllocateWholeTensors = CheckCorrectness || (!DistOnly);

    if (infile.empty()) {
    stCooA = make_unique<SparseTensorCOO<float>>(dims);
    init_random_sparse_2d<float>(*stCooA, N, 1, density);
    }
    stCooQ = make_unique<SparseTensorCOO<float>>(featDims);
    const LevelType denseLvl =
    *mlir::sparse_tensor::buildLevelType(LevelFormat::Dense, {});
    const LevelType compressedLvl = *mlir::sparse_tensor::buildLevelType(
    LevelFormat::Compressed, false, false);
    const LevelType kCSR[] = {denseLvl, compressedLvl};
    const LevelType kCSRV[] = {denseLvl, compressedLvl, denseLvl};
    const LevelType kDenseV[] = {denseLvl, denseLvl, denseLvl};
    const uint64_t src2tgt[] = {0, 1, 2};

    auto stA = AllocateWholeTensors
    ? SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(dims), dims.data(), std::size(dims),
    dims.data(), kCSRV, src2tgt, src2tgt, stCooA.get())
    : nullptr;

    getRss();
    if (LocalOnly)
    stCooA.reset();
    getRss();

    fmt::println(">>> Allocated A");
    fflush(nullptr);

    system("date");
    auto stQ =
    AllocateWholeTensors
    ? SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(featDims), featDims.data(), std::size(featDims),
    featDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get())
    : nullptr;
    if (stQ)
    init_feats(stQ, infile, N, Dh, Nh);
    if (stQ) {
    fmt::println(">>> Allocated Q");
    fflush(nullptr);

    system("date");
    }
    auto stK =
    AllocateWholeTensors
    ? SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(featDims), featDims.data(), std::size(featDims),
    featDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get())
    : nullptr;
    if (stK)
    init_feats(stK, infile, N, Dh, Nh);
    if (stK) {
    fmt::println(">>> Allocated K");
    fflush(nullptr);

    system("date");
    }
    auto stV =
    AllocateWholeTensors
    ? SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(featDims), featDims.data(), std::size(featDims),
    featDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get())
    : nullptr;
    if (stV)
    init_feats(stV, infile, N, Dh, Nh);
    if (stV) {
    fmt::println(">>> Allocated V");
    fflush(nullptr);

    system("date");
    }
    stCooQ.reset();

    std::vector<std::chrono::milliseconds> localTimes;
    auto start = std::chrono::high_resolution_clock::now();
    auto gold_out =
    AllocateWholeTensors
    ? static_cast<SparseTensorStorage<index_t, index_t, float> *>(
    pte_local_bspmm(stA, stV))
    : nullptr;
    auto end = std::chrono::high_resolution_clock::now();
    localTimes.push_back(
    std::chrono::duration_cast<std::chrono::milliseconds>(end - start));
    if (!CheckCorrectness && gold_out) {
    std::vector<float> *o;
    gold_out->getValues(&o);
    auto outfile = infile.substr(infile.find_last_of('/') + 1);
    outfile =
    fmt::format("{}.res", outfile.substr(0, outfile.find_first_of('.')));
    write_tensor<float>(*o, outfile.c_str());
    delSparseTensor((void *)gold_out);
    }

    if (AllocateWholeTensors) {
    fmt::println(">>> Done running local kernel");
    fflush(nullptr);
    system("date");
    getRss();
    }
    for (auto i : llvm::seq(0u, !AllocateWholeTensors ? 0u : nTimes)) {
    auto start = std::chrono::high_resolution_clock::now();
    auto out = static_cast<SparseTensorStorage<index_t, index_t, float> *>(
    pte_local_bspmm(stA, stV));
    auto end = std::chrono::high_resolution_clock::now();
    localTimes.push_back(
    std::chrono::duration_cast<std::chrono::milliseconds>(end - start));
    delSparseTensor((void *)out);
    }
    // print average time
    if (AllocateWholeTensors)
    fmt::print("Local time: {}ms\n",
    float(std::accumulate(localTimes.begin(), localTimes.end(),
    std::chrono::milliseconds(0))
    .count()) /
    localTimes.size());

    delSparseTensor((void *)stA);
    delSparseTensor((void *)stQ);
    delSparseTensor((void *)stK);
    delSparseTensor((void *)stV);
    if (LocalOnly) {
    #if defined(USE_KOKKOS)
    lapis_finalize();
    #endif
    #if DISTRIBUTED
    MPI_Finalize();
    #endif
    return 0;
    }
    #if DISTRIBUTED
    auto aPartitionPlan = get2DPartition(N, N, 0, Nparts, 1);
    auto ptA = PartTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(aPartitionPlan), aPartitionPlan.data(), std::size(dims),
    dims.data(), kCSR, stCooA.get());
    stCooA.reset();
    fmt::println(">>> {}: allocated ptA", rank);
    fflush(nullptr);

    system("date");
    getRss();
    auto qPartitionPlan = get2DPartition(N, Dh, Nh, Nparts, 1);
    auto myPartSpec =
    std::span(qPartitionPlan)
    .subspan(rank * std::size(featDims) * 2, std::size(featDims) * 2);
    auto fileOffset = myPartSpec[0] * Product(featDims, 1) * sizeof(float);
    auto partFeatDims = std::vector<size_t>(3);
    auto [aLo, aHi] =
    SplitLoHiPoint(llvm::ArrayRef(myPartSpec.data(), std::size(myPartSpec)));
    // fmt::print("Rank: {} aHi: {} aLo: {}\n", rank, aHi, aLo);
    SubtractPoints(partFeatDims, aHi, aLo);
    // fmt::print("Rank: {} partSpec: {} fileOffset: {} partFeatDims: {}\n", rank,
    // myPartSpec, fileOffset, partFeatDims);

    auto stPartQ = SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(partFeatDims), partFeatDims.data(), std::size(partFeatDims),
    partFeatDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get());
    fmt::println(">>> {}: allocated ptQ", rank);
    fflush(nullptr);

    system("date");
    getRss();
    auto stPartK = SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(partFeatDims), partFeatDims.data(), std::size(partFeatDims),
    partFeatDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get());
    auto stPartV = SparseTensorStorage<index_t, index_t, float>::newFromCOO(
    std::size(partFeatDims), partFeatDims.data(), std::size(partFeatDims),
    partFeatDims.data(), kDenseV, src2tgt, src2tgt, stCooQ.get());
    init_feats(stPartQ, infile, partFeatDims[0], partFeatDims[1], partFeatDims[2],
    fileOffset);
    init_feats(stPartK, infile, partFeatDims[0], partFeatDims[1], partFeatDims[2],
    fileOffset);
    init_feats(stPartV, infile, partFeatDims[0], partFeatDims[1], partFeatDims[2],
    fileOffset);
    auto ptQ =
    PartTensorStorage<index_t, index_t, float>::newFromSparseTensorStorage(
    std::size(qPartitionPlan), qPartitionPlan.data(), std::size(featDims),
    featDims.data(), stPartQ);
    auto ptK =
    PartTensorStorage<index_t, index_t, float>::newFromSparseTensorStorage(
    std::size(qPartitionPlan), qPartitionPlan.data(), std::size(featDims),
    featDims.data(), stPartK);
    auto ptV =
    PartTensorStorage<index_t, index_t, float>::newFromSparseTensorStorage(
    std::size(qPartitionPlan), qPartitionPlan.data(), std::size(featDims),
    featDims.data(), stPartV);

    fmt::println(">>> {}: allocated ptV", rank);
    fflush(nullptr);

    system("date");
    getRss();
    localTimes.clear();
    start = std::chrono::high_resolution_clock::now();
    auto out = static_cast<SparseTensorStorage<index_t, index_t, float> *>(
    pte_sparse_mha(ptA, ptQ, ptK, ptV, N, N, Dh, Nh));
    end = std::chrono::high_resolution_clock::now();
    localTimes.push_back(
    std::chrono::duration_cast<std::chrono::milliseconds>(end - start));
    fmt::println(">>> {}: Done dist exec", rank);
    fflush(nullptr);

    system("date");
    getRss();
    if (!CheckCorrectness) {
    std::vector<float> *o;
    out->getValues(&o);
    int comm_size;
    // get comm size
    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
    auto outfile = infile.substr(infile.find_last_of('/') + 1);
    outfile = fmt::format("{}.{}.{}.res",
    outfile.substr(0, outfile.find_first_of('.')), rank,
    comm_size);
    write_tensor<float>(*o, outfile.c_str());
    delSparseTensor((void *)out);
    }
    for (auto i : llvm::seq(0u, CheckCorrectness ? 0u : nTimes)) {
    auto start = std::chrono::high_resolution_clock::now();
    auto out = static_cast<SparseTensorStorage<index_t, index_t, float> *>(
    pte_sparse_mha(ptA, ptQ, ptK, ptV, N, N, Dh, Nh));
    auto end = std::chrono::high_resolution_clock::now();
    localTimes.push_back(
    std::chrono::duration_cast<std::chrono::milliseconds>(end - start));
    delSparseTensor((void *)out);
    }
    // print average time
    fmt::print("Dist time: {}ms\n",
    float(std::accumulate(localTimes.begin(), localTimes.end(),
    std::chrono::milliseconds(0))
    .count()) /
    localTimes.size());

    if (CheckCorrectness) {
    std::vector<float> *o, *g;
    out->getValues(&o);
    gold_out->getValues(&g);
    size_t partSize = (N * Dh * Nh / Nparts);
    auto partBegin = partSize * _mlir_ciface_mpi_getRank();
    auto partEnd = partBegin + partSize;
    for (auto i : llvm::seq(partBegin, partEnd)) {
    auto localI = i - partBegin;
    if (g->at(i) != o->at(localI)) {
    fmt::print("Mismatch at rank {} index: {} expected: {} got: {}\n", rank,
    i, g->at(i), o->at(localI));
    return 1;
    }
    }
    }
    // fmt::print("Q tensor: {}\n", Q);
    // fmt::print("K tensor: {}\n", K);
    // fmt::print("V tensor: {}\n", V);
    #define PRINT_OUTPUT(out) \
    { \
    std::vector<float> *o; \
    out->getValues(&o); \
    fmt::print(#out " tensor: {}\n", *o); \
    write_tensor<float>(*o, #out ".dat"); \
    }
    // PRINT_OUTPUT(out)
    // PRINT_OUTPUT(gold_out)
    #undef PRINT_OUTPUT
    #if defined(USE_KOKKOS)
    lapis_finalize();
    #endif
    MPI_Finalize();
    #endif
    return 0;
    }