Skip to content

Instantly share code, notes, and snippets.

@wush978
Created May 18, 2020 17:11
Show Gist options
  • Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.
Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.

Revisions

  1. wush978 created this gist May 18, 2020.
    73 changes: 73 additions & 0 deletions csr2csc.pyx
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,73 @@
    %%cython --annotate --cplus --compile-args=-fopenmp --link-args=-fopenmp
    cimport cython
    from cython.parallel cimport prange, parallel, threadid
    import numpy as np
    cimport numpy as np
    import scipy
    cimport openmp
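    # Serial alternative, kept for reference: std::sort from <algorithm> could stand
    # in for the GNU parallel-mode sort declared below.
    #cdef extern from "<algorithm>" namespace "std":
    #    cdef void sort[RI](RI first, RI last)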

    # Parallel algorithms from the GNU <parallel/algorithm> header (namespace
    # __gnu_parallel).  The cell is compiled and linked with -fopenmp, which these
    # implementations rely on.
    #
    # The block is not declared ``nogil``, so the functions below must be called
    # while the GIL is held.  ``except +`` lets Cython translate a thrown C++
    # exception into a Python exception.
    cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel":
        # Sort the range [first, last) in place using operator<.
        cdef void sort[RI](RI first, RI last) except +
        # Copy [first, last) to d_first, dropping consecutive duplicates.
        cdef void unique_copy[II, OI](II first, II last, OI d_first) except +

    from libcpp.vector cimport vector
    from libcpp.utility cimport pair
    # Coordinate of one stored entry as a (column, row) pair: `first` holds the
    # column index and `second` the row index.  Because pairs compare by `first`
    # and then by `second`, sorting a vector of Index orders entries column by
    # column with ascending rows inside each column.
    ctypedef pair[np.int64_t,np.int64_t] Index

    # Build the CSC index structure (dst_indptr, dst_indices) of a sparse matrix
    # from its CSR index structure (src_indptr, src_indices).  Only the sparsity
    # pattern is converted; no value array is read or written.
    #
    # Parameters
    #   src_indptr   CSR row pointer, read at positions 0..nrow.
    #   src_indices  CSR column indices, read at positions 0..nnz-1.
    #   dst_indptr   output CSC column pointer, written at positions 0..ncol.
    #   dst_indices  output CSC row indices, written at positions 0..nnz-1.
    #   nrow, ncol   matrix shape.
    #   nnz          number of stored entries.
    #
    # All pointers must address contiguous int64 storage of at least the sizes
    # above.  Column indices are not range-checked: every value in src_indices
    # must be smaller than ncol.
    cdef void __csr_to_csc(
        np.int64_t *src_indptr,
        np.int64_t *src_indices,
        np.int64_t *dst_indptr,
        np.int64_t *dst_indices,
        size_t nrow,
        size_t ncol,
        size_t nnz,
    ):
        # One (col, row) pair per stored entry.
        cdef vector[Index] index
        # One column histogram per potential OpenMP thread; slots stay empty until
        # the owning thread resizes them inside the parallel region.
        cdef vector[vector[np.int64_t]] buffer = vector[vector[np.int64_t]](openmp.omp_get_max_threads())
        cdef size_t i, j

        # Step 1: expand the CSR structure into explicit (col, row) pairs.  Each
        # position j belongs to exactly one row range, so rows can be processed
        # in parallel without two iterations writing the same element.
        with nogil:
            index.resize(nnz)
            dst_indptr[0] = 0
            for i in prange(nrow):
                for j in range(src_indptr[i], src_indptr[i+1]):
                    index[j].second = i  # row
                    index[j].first = src_indices[j]  # col

        # Step 2: sort by (col, row).  The extern declaration of `sort` is not
        # marked nogil, so the call has to happen with the GIL held, i.e. between
        # the two nogil sections.
        sort[vector[Index].iterator](index.begin(), index.end())

        # Step 3: emit row indices in sorted order and count entries per column.
        with nogil, parallel():
            # Every thread counts into its own histogram so that no two threads
            # increment the same counter.
            buffer[threadid()].resize(ncol)
            for i in prange(nnz):
                buffer[threadid()][index[i].first] += 1
                dst_indices[i] = index[i].second
            # Sum the per-thread histograms into dst_indptr[1:].  The entry is
            # zeroed first so the result does not depend on the caller having
            # cleared the output array.
            for i in prange(ncol):
                dst_indptr[i+1] = 0
                for j in range(buffer.size()):
                    # Skip slots that were never resized (presumably threads that
                    # did not take part in the team).
                    if buffer[j].size() > 0:
                        dst_indptr[i+1] += buffer[j][i]

        # Step 4: a serial prefix sum turns per-column counts into column offsets.
        for i in range(ncol):
            dst_indptr[i+1] = dst_indptr[i+1] + dst_indptr[i]


    # Return a raw pointer to the first element of a 1-D int64 array, or NULL when
    # the array is empty.
    #
    # Indexing element 0 of an empty array is out of bounds, so the empty case is
    # handled explicitly; callers must not dereference the result when the array
    # has no elements.  The pointer is only meaningful for pointer arithmetic if
    # `arr` is contiguous, and it is valid only while `arr` is kept alive.
    cdef np.int64_t* getp(np.ndarray[np.int64_t, ndim = 1] arr):
        if arr.shape[0] == 0:
            return NULL
        return &arr[0]

    def csr_to_csc(m):
        """Convert a binary (all-ones) SciPy CSR matrix to CSC format.

        Only the index structure is converted; the value array of ``m`` is
        reused unchanged in the result, which is why every stored value must
        equal 1.

        Parameters
        ----------
        m : scipy.sparse.csr_matrix
            Input matrix whose ``indptr`` and ``indices`` are int64 and whose
            stored values are all 1.

        Returns
        -------
        scipy.sparse.csc_matrix
            Matrix with the same shape and sparsity pattern as ``m``.

        Raises
        ------
        RuntimeError
            If ``m`` is not a ``csr_matrix``, if ``indptr`` or ``indices`` is
            not int64, or if any stored value differs from 1.
        """
        # The module only does ``import scipy``; import the subpackage explicitly
        # instead of relying on some other code having loaded it already.
        import scipy.sparse

        # Use the public class name: the ``scipy.sparse.csr`` submodule is a
        # deprecated private namespace.
        if not isinstance(m, scipy.sparse.csr_matrix):
            raise RuntimeError("m is not a csr_matrix")
        if m.indptr.dtype != np.int64:
            raise RuntimeError("The indptr is not int64")
        # Raise instead of assert so the check cannot be compiled out.
        if m.indices.dtype != np.int64:
            raise RuntimeError("The indices is not int64")
        if not np.all(m.data == 1):
            raise RuntimeError("The data is not all 1")

        # The kernel walks raw pointers, so the inputs must be contiguous.
        # ascontiguousarray returns the array itself when it already is.
        src_indptr = np.ascontiguousarray(m.indptr)
        src_indices = np.ascontiguousarray(m.indices)
        dst_indptr = np.zeros(m.shape[1] + 1, dtype=np.int64)
        dst_indices = np.zeros(len(src_indices), dtype=np.int64)
        __csr_to_csc(
            getp(src_indptr),
            getp(src_indices),
            getp(dst_indptr),
            getp(dst_indices),
            m.shape[0],
            m.shape[1],
            len(src_indices),
        )
        return scipy.sparse.csc_matrix((m.data, dst_indices, dst_indptr), shape=m.shape)