Skip to content

Instantly share code, notes, and snippets.

@pmadhyastha
Created June 1, 2014 17:11
Show Gist options
  • Save pmadhyastha/7540653ff4fbc35155dd to your computer and use it in GitHub Desktop.
Save pmadhyastha/7540653ff4fbc35155dd to your computer and use it in GitHub Desktop.
Computing the SVD of a matrix using LAPACK's dgesdd routine, documented here: http://www.netlib.org/lapack/explore-html/db/db4/dgesdd_8f.html
#!/usr/bin/env python
from numpy cimport int
# External declaration of LAPACK's dgesdd (Fortran symbol, hence the trailing
# underscore), used to compute the SVD of an m x n double-precision matrix.
# Every argument is passed by pointer, as Fortran routines expect; the order
# follows the netlib dgesdd reference.
cdef extern void dgesdd_(char *jobz, int *m, int *n,
                         double *a, int *lda,
                         double *s,
                         double *u, int *ldu,
                         double *vt, int *ldvt,
                         double *work, int *lwork,
                         int *iwork,
                         int *info)
import numpy as np
import pyximport

# Compile lsvd.pyx on first import, linking against the system LAPACK and
# adding NumPy's C headers (required by `cimport numpy` in lsvd.pyx).
pyximport.install(setup_args={
    'options': {
        'build_ext': {
            'libraries': ['lapack'],
            'include_dirs': [np.get_include()],
        },
    },
})

from lsvd import fastsvd


def reconstructs(x):
    """Return True if the factors from ``fastsvd(x)`` multiply back to ``x``.

    Only the first ``min(M, N)`` columns of U and rows of VT are used: in a
    full factorization the remaining ones are multiplied by zero, so this
    check works for square and non-square inputs alike.
    """
    u, s, vt = fastsvd(x)
    k = min(x.shape)
    # (M, k) * (k,) scales column j of U by s[j], i.e. U_k @ diag(s_k).
    return np.allclose(np.dot(u[:, :k] * s[:k], vt[:k, :]), x)


if __name__ == '__main__':
    # Square input: expected to print True.
    print(reconstructs(np.random.rand(100, 100)))
    # Non-square input (100 x 200): expected to print True.
    print(reconstructs(np.random.rand(100, 200)))

# Traceback originally reported for the non-square (100 x 200) case:
#   ...
#   .pyxbld/temp.linux-x86_64-2.7/pyrex/lsvd.c:1834)()
# /usr/lib/python2.7/dist-packages/numpy/core/numeric.pyc in asfortranarray(a, dtype)
#     582
#     583     """
# --> 584     return array(a, dtype, copy=False, order='F', ndmin=1)
#     585
#     586 def require(a, dtype=None, requirements=None):
# ValueError: On entry to DGESDD parameter number 10 had an illegal value
# (Parameter 10 of DGESDD is LDVT; VT is N x N, so LDVT must be at least N.)
#!/usr/bin/env python
cimport cython
import numpy as np
cimport numpy as np
from numpy.compat import asbytes
from numpy import int
from clapack cimport dgesdd_
@cython.boundscheck(False)
cdef inline int int_max(int a, int b):
    # Larger of two C ints (returns `a` on ties).
    if a < b:
        return b
    return a
@cython.boundscheck(False)
cdef inline int int_min(int a, int b):
    # Smaller of two C ints (returns `a` on ties).
    if b < a:
        return b
    return a
@cython.boundscheck(False)
def fastsvd(mat):
    """Compute the full singular value decomposition of a 2-D array.

    Calls LAPACK ``dgesdd`` with ``JOBZ='A'``. For an (M, N) input with
    K = min(M, N), the result satisfies
    ``mat ~= U[:, :K] @ diag(S) @ VT[:K, :]``.

    Parameters
    ----------
    mat : array_like, shape (M, N)
        Input matrix. It is converted to a float64 copy, so the caller's
        array is never modified.

    Returns
    -------
    U : ndarray, shape (M, M), Fortran order
    S : ndarray, shape (K,)
        Singular values.
    VT : ndarray, shape (N, N), Fortran order

    Raises
    ------
    ValueError
        If ``mat`` is not 2-D, or LAPACK reports an illegal argument
        (``info < 0``).
    numpy.linalg.LinAlgError
        If LAPACK reports a failure (``info > 0``).
    """
    # LAPACK expects column-major storage. Passing the C-ordered input as-is
    # made LAPACK factor the transpose. dgesdd also overwrites A, so always
    # work on a private Fortran-ordered float64 copy.
    cdef np.ndarray[double, ndim=2, mode='fortran'] a = np.array(
        mat, dtype=np.double, order='F', copy=True)
    cdef int M = a.shape[0]
    cdef int N = a.shape[1]
    cdef int K = int_min(M, N)

    # Leading dimensions: A and U have M rows, VT has N rows.
    # LDVT = M (the old value) is illegal whenever M < N.
    cdef int lda = int_max(M, 1)
    cdef int ldu = int_max(M, 1)
    cdef int ldvt = int_max(N, 1)

    # Only K singular values exist; a length-M buffer left garbage when M > N.
    cdef np.ndarray[double, ndim=1] s = np.empty(K, dtype=np.double)
    cdef np.ndarray[double, ndim=2, mode='fortran'] u = np.empty(
        (M, M), dtype=np.double, order='F')
    cdef np.ndarray[double, ndim=2, mode='fortran'] vt = np.empty(
        (N, N), dtype=np.double, order='F')

    # Minimum workspace for JOBZ='A': 4*K*K + 6*K + max(M, N) (at least 1).
    cdef int lwork = int_max(K * (6 + 4 * K) + int_max(M, N), 1)
    cdef np.ndarray[double, ndim=1] work = np.empty(lwork, dtype=np.double)
    # IWORK holds 8*K Fortran INTEGERs, i.e. C ints; np.intc matches that.
    cdef np.ndarray iwork = np.empty(int_max(8 * K, 1), dtype=np.intc)

    cdef int info = 0
    # Bytes literal: a unicode 'A' cannot be coerced to char under Python 3.
    cdef char jobz = b'A'

    dgesdd_(&jobz, &M, &N, <double *> a.data, &lda,
            <double *> s.data, <double *> u.data, &ldu,
            <double *> vt.data, &ldvt, <double *> work.data, &lwork,
            <int *> iwork.data, &info)

    if info < 0:
        raise ValueError(
            "dgesdd: parameter number %d had an illegal value" % (-info))
    if info > 0:
        raise np.linalg.LinAlgError(
            "SVD did not converge (dgesdd info=%d)" % info)

    # All three arrays were freshly allocated here, so no defensive copies.
    return u, s, vt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment