#!/usr/bin/env python
"""Thin Cython wrapper around LAPACK ``dgesdd`` (divide-and-conquer SVD)."""
cimport cython
import numpy as np
cimport numpy as np
# NOTE(review): ``asbytes`` is unused in this module, and ``numpy.compat`` is
# deprecated in recent NumPy releases -- consider dropping this import.
from numpy.compat import asbytes
# NOTE(review): the ``numpy.int`` alias was removed in NumPy 1.24 (this import
# raises ImportError there) and it shadows the builtin ``int``; fastsvd below
# no longer relies on it -- consider dropping this import.
from numpy import int
from clapack cimport dgesdd_

# Initialise the NumPy C API before any ndarray attribute/buffer access.
np.import_array()


cdef inline int int_max(int a, int b):
    """Return the larger of two C ints."""
    return a if a >= b else b


cdef inline int int_min(int a, int b):
    """Return the smaller of two C ints."""
    return a if a <= b else b


cdef int _check_dgesdd_info(int info) except -1:
    """Raise if dgesdd reported failure through INFO; return 0 on success.

    INFO < 0 means argument number -INFO was illegal; INFO > 0 means the
    divide-and-conquer iteration did not converge.
    """
    if info < 0:
        raise ValueError("dgesdd: argument %d had an illegal value" % -info)
    if info > 0:
        raise np.linalg.LinAlgError("SVD did not converge")
    return 0


@cython.boundscheck(False)
def fastsvd(mat):
    """Compute the full singular value decomposition of a real 2-D matrix.

    Calls LAPACK ``dgesdd`` with ``JOBZ='A'``.  With ``k = min(M, N)``, the
    result satisfies ``mat ~= U[:, :k] @ np.diag(S) @ VT[:k, :]``.

    Parameters
    ----------
    mat : array_like, shape (M, N)
        Real input matrix.  A Fortran-ordered float64 copy is decomposed, so
        the caller's array is never modified.

    Returns
    -------
    U : ndarray, shape (M, M)
        Left singular vectors (columns), Fortran-ordered.
    S : ndarray, shape (min(M, N),)
        Singular values, in the order LAPACK returns them (descending).
    VT : ndarray, shape (N, N)
        Right singular vectors (rows), Fortran-ordered.

    Raises
    ------
    ValueError
        If ``mat`` is not 2-D, is complex, or LAPACK reports an illegal
        argument.
    numpy.linalg.LinAlgError
        If ``dgesdd`` fails to converge.
    """
    arr = np.asarray(mat)
    if arr.ndim != 2:
        raise ValueError("fastsvd expects a 2-D array, got %d-D" % arr.ndim)
    if np.iscomplexobj(arr):
        # Casting to float64 would silently drop the imaginary part.
        raise ValueError("fastsvd only supports real input (LAPACK dgesdd)")

    # dgesdd overwrites A in place and reads it as column-major: decompose a
    # Fortran-ordered float64 copy so the caller's data is left untouched and
    # a C-ordered input is not decomposed as its transpose.
    cdef np.ndarray[double, ndim=2, mode='fortran'] a = np.array(
        arr, dtype=np.double, order='F', copy=True)
    cdef int M = a.shape[0]
    cdef int N = a.shape[1]
    cdef int mn = int_min(M, N)

    if mn == 0:
        # LAPACK requires leading dimensions >= 1, so it cannot take an empty
        # matrix; identity factors with no singular values are a valid full
        # SVD of an empty matrix.
        return (np.eye(M, dtype=np.double, order='F'),
                np.empty(0, dtype=np.double),
                np.eye(N, dtype=np.double, order='F'))

    cdef char jobz = b'A'   # all M columns of U and all N rows of VT
    cdef int lda = M
    cdef int ldu = M        # U is M x M
    cdef int ldvt = N       # VT is N x N, so LAPACK requires LDVT >= N
    cdef int info = 0
    cdef int lwork = -1     # -1 requests a workspace-size query only

    # dgesdd fills exactly min(M, N) singular values.
    cdef np.ndarray[double, ndim=1] s = np.empty(mn, dtype=np.double)
    cdef np.ndarray[double, ndim=2, mode='fortran'] u = np.empty(
        (M, M), dtype=np.double, order='F')
    cdef np.ndarray[double, ndim=2, mode='fortran'] vt = np.empty(
        (N, N), dtype=np.double, order='F')
    # Integer workspace of 8*min(M, N) entries.  np.intc matches the C ``int``
    # already passed for M, N and the leading dimensions (assumes an LP64
    # LAPACK build -- confirm against the clapack declaration).
    cdef np.ndarray iwork = np.empty(8 * mn, dtype=np.intc)
    cdef np.ndarray[double, ndim=1] work = np.empty(1, dtype=np.double)

    # Pass 1: LAPACK writes the optimal LWORK into work[0].
    dgesdd_(&jobz, &M, &N, a.data, &lda, s.data, u.data, &ldu,
            vt.data, &ldvt, work.data, &lwork, iwork.data, &info)
    _check_dgesdd_info(info)
    lwork = int_max(<int>work[0], 1)
    work = np.empty(lwork, dtype=np.double)

    # Pass 2: the actual decomposition.
    dgesdd_(&jobz, &M, &N, a.data, &lda, s.data, u.data, &ldu,
            vt.data, &ldvt, work.data, &lwork, iwork.data, &info)
    _check_dgesdd_info(info)

    # u, s and vt were freshly allocated here, so they can be returned as-is.
    return u, s, vt