Created
May 18, 2020 17:11
-
-
Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.
Revisions
-
wush978 created this gist
May 18, 2020 .There are no files selected for viewing
%%cython --annotate --cplus --compile-args=-fopenmp --link-args=-fopenmp
# Parallel CSR -> CSC conversion for binary (all-ones) sparse matrices with
# int64 index arrays. Uses OpenMP through cython.parallel and the libstdc++
# parallel-mode sort.
cimport cython
from cython.parallel cimport prange, parallel, threadid
import numpy as np
cimport numpy as np
import scipy
import scipy.sparse
cimport openmp

# Serial alternative to the parallel-mode sort declared below:
#cdef extern from "<algorithm>" namespace "std":
#    cdef void sort[RI](RI first, RI last)

cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel":
    cdef void sort[RI](RI first, RI last) except +
    cdef void unique_copy[II, OI](II first, II last, OI d_first) except +

from libcpp.vector cimport vector
from libcpp.utility cimport pair

# (column, row) pair. std::pair compares lexicographically, so sorting a
# vector of Index orders entries by column first, then by row.
ctypedef pair[np.int64_t,np.int64_t] Index


# Fill the CSC arrays (dst_indptr, dst_indices) from the CSR arrays
# (src_indptr, src_indices).
#
#   src_indptr  : nrow + 1 row offsets of the CSR matrix.
#   src_indices : nnz column indices of the CSR matrix.
#   dst_indptr  : ncol + 1 entries; MUST be zero-initialised by the caller
#                 because per-column counts are accumulated into it.
#   dst_indices : nnz entries; receives the row index of every stored value.
#
# The nnz-sized pointers are never dereferenced when nnz == 0, so they may be
# NULL in that case.
cdef void __csr_to_csc(
    np.int64_t *src_indptr,
    np.int64_t *src_indices,
    np.int64_t *dst_indptr,
    np.int64_t *dst_indices,
    size_t nrow,
    size_t ncol,
    size_t nnz,
):
    cdef vector[Index] index
    # One per-thread histogram of column counts, sized lazily inside the
    # parallel region so that only threads that actually run allocate one.
    cdef vector[vector[np.int64_t]] buffer = vector[vector[np.int64_t]](openmp.omp_get_max_threads())
    cdef size_t i, j

    # Step 1: expand the CSR structure into explicit (col, row) pairs.
    with nogil:
        index.resize(nnz)
        for i in prange(nrow):
            for j in range(src_indptr[i],src_indptr[i+1]):
                index[j].second = i # row
                index[j].first = src_indices[j] # col

    # Step 2: sort by (col, row). The extern declaration is not marked nogil,
    # so the call is made while holding the GIL.
    sort[vector[Index].iterator](index.begin(), index.end())

    # Step 3: emit row indices in sorted order and count entries per column.
    with nogil, parallel():
        buffer[threadid()].resize(ncol)
        for i in prange(nnz):
            buffer[threadid()][index[i].first] += 1
            dst_indices[i] = index[i].second
        # Merge the per-thread histograms; buffers of threads that never ran
        # still have size 0 and are skipped.
        for i in prange(ncol):
            for j in range(buffer.size()):
                if buffer[j].size() > 0:
                    dst_indptr[i+1] += buffer[j][i]

    # Step 4: the prefix sum turns per-column counts into column offsets.
    for i in range(ncol):
        dst_indptr[i+1] = dst_indptr[i+1] + dst_indptr[i]


# Return a raw pointer to the first element of a 1-D int64 array, or NULL
# when the array is empty (taking &arr[0] of an empty array is an
# out-of-bounds access).
# NOTE(review): the pointer is later used with unit-stride arithmetic, so the
# array is assumed to be contiguous -- confirm for non-SciPy-built inputs.
cdef np.int64_t* getp(np.ndarray[np.int64_t, ndim = 1] arr):
    if arr.shape[0] == 0:
        return NULL
    return &arr[0]


def csr_to_csc(m):
    """Convert a binary CSR matrix to CSC format.

    Parameters
    ----------
    m : scipy.sparse.csr_matrix
        Matrix whose ``indptr`` and ``indices`` arrays have dtype int64 and
        whose stored values are all equal to 1.

    Returns
    -------
    scipy.sparse.csc_matrix
        Matrix with the same shape and sparsity pattern as ``m``. Because
        every stored value is 1, ``m.data`` is passed through unchanged
        instead of being permuted.

    Raises
    ------
    RuntimeError
        If ``m`` is not a CSR matrix, if ``indptr`` or ``indices`` is not
        int64, or if any stored value differs from 1.
    """
    if not isinstance(m, scipy.sparse.csr_matrix):
        raise RuntimeError("m is not a csr_matrix")
    if not m.indptr.dtype == np.int64:
        raise RuntimeError("The indptr is not int64")
    # Raise instead of assert: asserts are stripped under -O, and the other
    # validation failures in this function are RuntimeErrors.
    if not m.indices.dtype == np.int64:
        raise RuntimeError("The indices is not int64")
    if not np.all(m.data == 1):
        raise RuntimeError("The data is not all 1")
    # Must start at zero: the kernel accumulates column counts in place.
    dst_indptr = np.zeros(m.shape[1] + 1, dtype = np.int64)
    dst_indices = np.zeros(len(m.indices), dtype = np.int64)
    __csr_to_csc(
        getp(m.indptr),
        getp(m.indices),
        getp(dst_indptr),
        getp(dst_indices),
        m.shape[0],
        m.shape[1],
        len(m.indices),
    )
    return scipy.sparse.csc_matrix((m.data, dst_indices, dst_indptr), shape = m.shape)