-
-
Save smrjans/a2fa5e5a0cacc2fc0d25b40853adc593 to your computer and use it in GitHub Desktop.
Fisher vectors with sklearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import pdb | |
| from sklearn.datasets import make_classification | |
| from sklearn.mixture import GMM | |
def fisher_vector(xx, gmm):
    """Compute the Fisher vector of a set of descriptors.

    Parameters
    ----------
    xx: array_like, shape (N, D) or (D, )
        The set of descriptors. A 1-D input is treated as one descriptor.
    gmm: fitted Gaussian mixture model
        Gaussian mixture model of the descriptors with diagonal (or
        spherical) covariances. It must provide ``predict_proba``,
        ``weights_`` and ``means_``. Variances are read from ``covars_``
        (legacy ``sklearn.mixture.GMM``) or, when that attribute is
        missing, from ``covariances_`` (``sklearn.mixture.GaussianMixture``).

    Returns
    -------
    fv: ndarray, shape (K + 2 * D * K, )
        Fisher vector of the given descriptors. It concatenates, in order,
        the derivatives with respect to the mixing weights (K values), the
        means (K * D values, row-major) and the variances (K * D values).
        The derivatives are averaged over descriptors but are not normalized
        by the variances.

    Raises
    ------
    ValueError
        If ``xx`` contains no descriptors, or if the mixture's
        ``covariance_type`` is neither ``'diag'`` nor ``'spherical'``.

    Reference
    ---------
    J. Krapac, J. Verbeek, F. Jurie. Modeling Spatial Layout with Fisher
    Vectors for Image Categorization. In ICCV, 2011.
    http://hal.inria.fr/docs/00/61/94/03/PDF/final.r1.pdf
    """
    xx = np.atleast_2d(xx)
    if xx.size == 0:
        # Averaging over zero descriptors would divide by zero.
        raise ValueError("fisher_vector needs at least one descriptor")
    N = xx.shape[0]

    # The variance derivative below needs one variance per component and
    # feature; full/tied covariance matrices do not fit that formula.
    covariance_type = getattr(gmm, "covariance_type", "diag")
    if covariance_type not in ("diag", "spherical"):
        raise ValueError(
            "fisher_vector supports only 'diag' or 'spherical' covariances, "
            "got %r" % (covariance_type,))

    # Legacy GMM exposes ``covars_``; GaussianMixture exposes ``covariances_``.
    variances = getattr(gmm, "covars_", None)
    if variances is None:
        variances = gmm.covariances_
    variances = np.asarray(variances)
    if variances.ndim == 1:
        # GaussianMixture stores spherical variances with shape (K,). Make it
        # a (K, 1) column so it broadcasts over the D feature columns.
        variances = variances[:, np.newaxis]

    # Compute posterior probabilities.
    Q = gmm.predict_proba(xx)  # NxK

    # Compute the zeroth, first and second order sufficient statistics,
    # averaged over the N descriptors.
    Q_sum = np.sum(Q, 0)[:, np.newaxis] / N  # Kx1
    Q_xx = np.dot(Q.T, xx) / N  # KxD
    Q_xx_2 = np.dot(Q.T, xx ** 2) / N  # KxD

    # Compute derivatives with respect to mixing weights, means and variances.
    d_pi = Q_sum.ravel() - gmm.weights_
    d_mu = Q_xx - Q_sum * gmm.means_
    # Expansion of sum_n q_nk * (sigma^2 - (x_n - mu_k)^2) / N.
    d_sigma = (
        - Q_xx_2
        - Q_sum * gmm.means_ ** 2
        + Q_sum * variances
        + 2 * Q_xx * gmm.means_)

    # Merge derivatives into a vector.
    return np.hstack((d_pi, d_mu.flatten(), d_sigma.flatten()))
def main():
    """Short demo of ``fisher_vector`` on synthetic data.

    Draws 1000 samples with ``make_classification`` and fits a 64-component
    diagonal-covariance ``GMM`` on the first 900. It then computes the
    Fisher vector of the remaining 100 samples, prints its shape and
    returns it.
    """
    K = 64
    N = 1000
    n_test = 100
    xx, _ = make_classification(n_samples=N)
    xx_tr, xx_te = xx[:-n_test], xx[-n_test:]
    gmm = GMM(n_components=K, covariance_type='diag')
    gmm.fit(xx_tr)
    fv = fisher_vector(xx_te, gmm)
    # Report the result instead of stopping in an interactive debugger
    # (a leftover ``pdb.set_trace()`` used to halt the script here).
    print('Fisher vector shape: %s' % (fv.shape,))
    return fv
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment