Skip to content

Instantly share code, notes, and snippets.

@CrazyDaffodils
CrazyDaffodils / data_loader.py
Created September 1, 2020 20:36 — forked from kevinzakka/data_loader.py
Train, Validation and Test Split for torchvision Datasets
"""
Create train, valid, test iterators for CIFAR-10 [1].
Easily extended to MNIST, CIFAR-100 and Imagenet.
[1]: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from umap import UMAP
# Visualize data using UMAP.
# random_state pins the otherwise stochastic embedding so repeated runs give
# the same picture, matching the seeded t-SNE section of this file.
model = UMAP(n_neighbors=40, min_dist=0.4, n_components=2, random_state=123)
umap = model.fit_transform(X_std)
# DataFrame.join aligns on the row index — assumes `labels` carries the
# default 0..n-1 index and a 'Class' column (used as hue below); TODO confirm.
umap_df = pd.DataFrame(data=umap, columns=['UMAP1', 'UMAP2']).join(labels)
# NOTE(review): 5 colours, presumably one per class — confirm against `labels`.
palette = sns.color_palette("muted", n_colors=5)
sns.set_style("white")
# Give the scatter its own figure; without this it is drawn on whatever axes
# happen to be current and ends up sharing them with the next scatter plot.
plt.figure()
sns.scatterplot(
    x='UMAP1',
    y='UMAP2',
    hue='Class',
    data=umap_df,
    palette=palette,
    linewidth=0.2,
    s=30,
    alpha=1,
).set_title('UMAP')
plt.show()
from sklearn.manifold import TSNE
import seaborn as sns
# Visualize data using t-SNE.
# random_state is fixed so the embedding is reproducible between runs.
model = TSNE(learning_rate=10, n_components=2, random_state=123, perplexity=30)
tsne = model.fit_transform(X_std)
# DataFrame.join aligns on the row index — assumes `labels` carries the
# default 0..n-1 index and a 'Class' column (used as hue below); TODO confirm.
tsne_df = pd.DataFrame(data=tsne, columns=['t-SNE1', 't-SNE2']).join(labels)
# NOTE(review): 5 colours, presumably one per class — confirm against `labels`.
palette = sns.color_palette("muted", n_colors=5)
sns.set_style("white")
# Start a fresh figure so this scatter is not drawn on top of an earlier plot
# that is still the current axes.
plt.figure()
sns.scatterplot(
    x='t-SNE1',
    y='t-SNE2',
    hue='Class',
    data=tsne_df,
    palette=palette,
    linewidth=0.2,
    s=30,
    alpha=1,
).set_title('t-SNE')
plt.show()
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# Fit a PCA with all components kept and plot how much of the total variance
# the first k components explain together.
pca_std = PCA().fit(X_std)
percent_variance = pca_std.explained_variance_ratio_ * 100
# Running total in percent, so the curve matches the "(%)" axis label.
cumulative_variance = np.cumsum(percent_variance)
# Entry i of the running total covers i + 1 components, so count from 1
# instead of letting matplotlib use the 0-based array index.
component_counts = np.arange(1, len(cumulative_variance) + 1)
plt.figure()
plt.plot(component_counts, cumulative_variance)
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance (%)')
plt.show()
from sklearn.decomposition import PCA
import seaborn as sns
# Visualize data using Principal Component Analysis.
print("Principal Component Analysis (PCA)")
# Keep only the first two principal components for plotting.
pca = PCA(n_components=2).fit_transform(X_std)
# DataFrame.join aligns on the row index — assumes `labels` carries the
# default 0..n-1 index and a 'Class' column (used as hue below); TODO confirm.
pca_df = pd.DataFrame(data=pca, columns=['PC1', 'PC2']).join(labels)
# NOTE(review): 5 colours, presumably one per class — confirm against `labels`.
palette = sns.color_palette("muted", n_colors=5)
sns.set_style("white")
# Own figure plus an explicit show(): this plot comes after the previous
# plt.show(), so without them a plain script run never displays it.
plt.figure()
sns.scatterplot(
    x='PC1',
    y='PC2',
    hue='Class',
    data=pca_df,
    palette=palette,
    linewidth=0.2,
    s=30,
    alpha=1,
).set_title('PCA')
plt.show()