Created
June 1, 2014 17:11
-
-
Save pmadhyastha/7540653ff4fbc35155dd to your computer and use it in GitHub Desktop.
Computing the SVD of a matrix using LAPACK's dgesdd routine, which is documented here: http://www.netlib.org/lapack/explore-html/db/db4/dgesdd_8f.html
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| from numpy cimport int | |
# Declaration of LAPACK's dgesdd (the SVD routine this gist wraps).
# The symbol follows the Fortran calling convention: every argument,
# scalars included, is passed by pointer.
cdef extern void dgesdd_(
    char *jobz,                # job option selecting which factors to compute
    int *m, int *n,            # row / column counts of the input matrix
    double *a, int *lda,       # input matrix and its leading dimension
    double *s,                 # output: singular values
    double *u, int *ldu,       # output: U and its leading dimension
    double *vt, int *ldvt,     # output: V**T and its leading dimension
    double *work, int *lwork,  # floating-point workspace and its length
    int *iwork,                # integer workspace
    int *info)                 # output: status code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pyximport

# Compile lsvd.pyx on import, linking against LAPACK and exposing NumPy's
# C headers to the generated extension module.
pyximport.install(setup_args={'options':
                              {'build_ext':
                               {'libraries': 'lapack',
                                'include_dirs': np.get_include(),
                                }}})

# Must come after pyximport.install() so the .pyx file can be built.
from lsvd import fastsvd


def reconstruction_error(x):
    """Return max |u * diag(s) * vt - x| for the factors fastsvd gives for x.

    A copy of ``x`` is passed to fastsvd so the reference matrix stays intact
    even if the wrapper hands its input buffer straight to LAPACK (dgesdd
    overwrites its input matrix).  Only the leading ``min(x.shape)`` columns
    of ``u`` and rows of ``vt`` take part in the product, which makes the
    check valid for rectangular matrices too (``np.diag(s)`` alone only
    works for square input).
    """
    u, s, vt = fastsvd(x.copy())
    k = min(x.shape)
    rebuilt = np.dot(u[:, :k] * s[:k], vt[:k, :])
    return np.abs(rebuilt - x).max()


def report(shape):
    """Run fastsvd on a random matrix of the given shape and print the result."""
    x = np.random.rand(*shape)
    err = reconstruction_error(x)
    verdict = 'OK' if err < 1e-8 else 'WRONG ANSWER'
    print('%s: max reconstruction error %.3e -> %s' % (shape, err, verdict))


if __name__ == '__main__':
    # Square case.  Originally reported as giving a wrong answer, i.e.
    # np.dot(u, np.dot(np.diag(s), vt)) != x
    report((100, 100))

    # Non-square case.  Originally reported as failing with:
    #
    #   .pyxbld/temp.linux-x86_64-2.7/pyrex/lsvd.c:1834)()
    #   /usr/lib/python2.7/dist-packages/numpy/core/numeric.pyc in asfortranarray(a, dtype)
    #       582
    #       583     """
    #   --> 584     return array(a, dtype, copy=False, order='F', ndmin=1)
    #       585
    #       586 def require(a, dtype=None, requirements=None):
    #   ValueError: On entry to DGESDD parameter number 10 had an illegal value
    #
    # Parameter number 10 of dgesdd is LDVT.
    report((100, 200))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| cimport cython | |
| import numpy as np | |
| cimport numpy as np | |
| from numpy.compat import asbytes | |
| from numpy import int | |
| from clapack cimport dgesdd_ | |
@cython.boundscheck(False)
cdef inline int int_max(int a, int b):
    # Larger of two C ints; `a` is returned when the two are equal.
    if a >= b:
        return a
    return b
@cython.boundscheck(False)
cdef inline int int_min(int a, int b):
    # Smaller of two C ints; `a` is returned when the two are equal.
    if a <= b:
        return a
    return b
@cython.boundscheck(False)
def fastsvd(mat):
    """Compute the full SVD of a real 2-D matrix with LAPACK's ``dgesdd``.

    Parameters
    ----------
    mat : array_like, shape (M, N)
        Matrix to factorise.  It is copied, so the caller's data is never
        modified.

    Returns
    -------
    U : ndarray, shape (M, M)
    S : ndarray, shape (min(M, N),)
        Singular values as returned by ``dgesdd``.
    VT : ndarray, shape (N, N)
        With ``K = min(M, N)``: ``mat == dot(U[:, :K] * S, VT[:K, :])`` up to
        rounding.

    Raises
    ------
    ValueError
        If ``mat`` is not 2-D, or LAPACK rejects an argument (``info < 0``).
    numpy.linalg.LinAlgError
        If the routine reports non-convergence (``info > 0``).
    """
    # LAPACK reads column-major storage and overwrites A.  Passing a C-ordered
    # array straight through would factorise the transpose and clobber the
    # caller's matrix, so always work on a private Fortran-ordered copy.
    cdef np.ndarray[double, ndim=2] a = np.array(
        mat, dtype=np.double, order='F', copy=True)
    cdef int M = a.shape[0]
    cdef int N = a.shape[1]
    cdef int K = int_min(M, N)

    # There are only min(M, N) singular values.
    cdef np.ndarray[double, ndim=1] s = np.empty(K, dtype=np.double)
    cdef np.ndarray[double, ndim=2] u = np.empty((M, M), dtype=np.double, order='F')
    cdef np.ndarray[double, ndim=2] vt = np.empty((N, N), dtype=np.double, order='F')

    cdef int info = 0
    cdef char jobz = 'A'

    # Leading dimensions: A and U have M rows, VT has N rows.  (Using M for
    # ldvt is what triggered "parameter number 10 had an illegal value" when
    # M < N.)  LAPACK requires every leading dimension to be at least 1.
    cdef int lda = int_max(1, M)
    cdef int ldu = int_max(1, M)
    cdef int ldvt = int_max(1, N)

    # iwork must hold 8*min(M, N) C ints; np.intc matches the `int *` in the
    # extern declaration (NumPy's default integer is 64-bit on many platforms).
    cdef np.ndarray[int, ndim=1] iwork = np.zeros(int_max(1, 8 * K), dtype=np.intc)

    # Workspace query: lwork == -1 makes dgesdd return the optimal workspace
    # size in work[0] without doing any computation.
    cdef int lwork = -1
    cdef double work_query = 0.0
    dgesdd_(&jobz, &M, &N, <double *> a.data, &lda,
            <double *> s.data, <double *> u.data, &ldu, <double *> vt.data,
            &ldvt, &work_query, &lwork, <int *> iwork.data, &info)
    if info != 0:
        raise ValueError(
            "dgesdd workspace query failed: illegal value in argument %d" % -info)

    # Never go below the documented minimum for jobz='A', in case the query
    # of an older LAPACK build under-reports.
    lwork = int_max(<int> work_query, K * (6 + 4 * K) + int_max(M, N))
    lwork = int_max(1, lwork)
    cdef np.ndarray[double, ndim=1] work = np.empty(lwork, dtype=np.double)

    dgesdd_(&jobz, &M, &N, <double *> a.data, &lda,
            <double *> s.data, <double *> u.data, &ldu, <double *> vt.data,
            &ldvt, <double *> work.data, &lwork, <int *> iwork.data, &info)
    if info < 0:
        raise ValueError("dgesdd: illegal value in argument %d" % -info)
    if info > 0:
        raise np.linalg.LinAlgError("dgesdd did not converge (info=%d)" % info)

    # u, s and vt are freshly allocated here, so they can be returned as-is.
    return u, s, vt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment