Created
June 8, 2020 09:06
-
-
Save dojeda/ad577aeab9e0111ce08aa663392c5359 to your computer and use it in GitHub Desktop.
Revisions
-
dojeda created this gist
Jun 8, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,106 @@ from functools import partial import matplotlib.pyplot as plt import numpy as np from pyriemann.utils.distance import distance_riemann @partial(np.vectorize, excluded=['ref']) def distance_riemann_v(c_xx, c_yy, ref): """ Vectorized version of distance_riemann projected in 2D Calculates the Riemannian distance between the reference matrix `ref` and a covariance matrix whose variance terms are `c_xx` and `c_yy`. Note that this function ignores the covariance term. Parameters ---------- c_xx: np.array Variances of the x channel. This array must have shape (n, ). c_yy: np.array Variances of the y channel. This array must have shape (n, ). ref: np.array Reference covariance matrix. This must be a matrix of size (2, 2). Returns ------- float: the Riemannian distance between [[c_xx, 0], [0, c_yy]] and `ref`. """ c = np.diag([c_xx, c_yy]) return distance_riemann(c, ref) @partial(np.vectorize, excluded=['ref']) def distance_euclid_v(c_xx, c_yy, ref): """ Vectorized version of Euclidean distance between 2D covariance matrices Calculates the Euclidean distance between the reference matrix `ref` and a covariance matrix whose variance terms are `c_xx` and `c_yy`. Note that this function ignores the covariance term. Parameters ---------- c_xx: np.array Variances of the x channel. This array must have shape (n, ). c_yy: np.array Variances of the y channel. This array must have shape (n, ). ref: np.array Reference covariance matrix. This must be a matrix of size (2, 2). Returns ------- float: the Riemannian distance between [[c_xx, 0], [0, c_yy]] and `ref`. """ return np.linalg.norm([ c_xx - ref[0, 0], c_yy - ref[1, 1] ]) def main(): # Reference matrix C = np.array([ [+3.1, -0.2], [-0.2, +5.4], ]) # Reference rejection distance d_reject = 1.5 # X and Y limits for plot vmin, vmax = 0.01, 30 # Mesh for 2D color/contour plot x = np.linspace(vmin, vmax, 50) y = x X, Y = np.meshgrid(x, y) # Distance values in mesh Zr = distance_riemann_v(X, Y, ref=C) Ze = distance_euclid_v(X, Y, ref=C) # Plotting fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5)) cb = axs[0].contourf(X, Y, Zr, levels=20, vmin=0, vmax=Zr.max(), cmap='RdYlBu_r') axs[0].contour(X, Y, Zr, levels=[d_reject], colors='g') cbar = fig.colorbar(cb, ax=axs[0]) cbar.ax.set_ylabel('Riemannian distance to reference') axs[0].set_xlabel('Cxx') axs[0].set_ylabel('Cyy') axs[0].set_title('Riemann') cb = axs[1].contourf(X, Y, Ze, levels=20, vmin=0, vmax=Ze.max(), cmap='RdYlBu_r') axs[1].contour(X, Y, Ze, levels=[d_reject], colors='g') cbar = fig.colorbar(cb, ax=axs[1]) cbar.ax.set_ylabel('Euclidean distance to reference') axs[1].set_xlabel('Cxx') axs[1].set_ylabel('Cyy') axs[1].set_title('Euclid') plt.show() if __name__ == '__main__': main()