-
-
Save skrish13/e356823ed80e042f7aeb778c64d75f40 to your computer and use it in GitHub Desktop.
Revisions
-
kevinzakka revised this gist
Aug 2, 2017 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -125,7 +125,7 @@ def get_test_loader(data_dir, - data_dir: path directory to the dataset. - batch_size: how many samples per batch to load. - shuffle: whether to shuffle the dataset after every epoch. - num_workers: number of subprocesses to use when loading the dataset. - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to True if using GPU. -
kevinzakka revised this gist
Aug 2, 2017 . 1 changed file with 1 addition and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,7 @@ # This is an example for the CIFAR-10 dataset. # There's a function for creating a train and validation iterator. # There's also a function for creating a test iterator. # Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4 from utils import plot_images -
kevinzakka revised this gist
Aug 2, 2017 . 2 changed files with 41 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,6 +2,8 @@ # There's a function for creating a train and validation iterator. # There's also a function for creating a test iterator. from utils import plot_images def get_train_valid_loader(data_dir, batch_size, augment, @@ -27,7 +29,6 @@ def get_train_valid_loader(data_dir, - random_seed: fix seed for reproducibility. - valid_size: percentage split of the training set used for the validation set. Should be a float in the range [0, 1]. - shuffle: whether to shuffle the train/validation indices. - show_sample: plot 9x9 sample grid of the dataset. - num_workers: number of subprocesses to use when loading the dataset. This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,39 @@ import matplotlib.pyplot as plt label_names = [ 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' ] def plot_images(images, cls_true, cls_pred=None): assert len(images) == len(cls_true) == 9 # Create figure with sub-plots. fig, axes = plt.subplots(3, 3) for i, ax in enumerate(axes.flat): # plot the image ax.imshow(images[i, :, :, :], interpolation='spline16') # get its equivalent class name cls_true_name = label_names[cls_true[i]] if cls_pred is None: xlabel = "{0} ({1})".format(cls_true_name, cls_true[i]) else: cls_pred_name = label_names[cls_pred[i]] xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name) ax.set_xlabel(xlabel) ax.set_xticks([]) ax.set_yticks([]) plt.show() -
kevinzakka created this gist
Aug 2, 2017 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,154 @@ # This is an example for the CIFAR-10 dataset. # There's a function for creating a train and validation iterator. # There's also a function for creating a test iterator. def get_train_valid_loader(data_dir, batch_size, augment, random_seed, valid_size=0.1, shuffle=True, show_sample=False, num_workers=4, pin_memory=False): """ Utility function for loading and returning train and valid multi-process iterators over the CIFAR-10 dataset. A sample 9x9 grid of the images can be optionally displayed. If using CUDA, num_workers should be set to 1 and pin_memory to True. Params ------ - data_dir: path directory to the dataset. - batch_size: how many samples per batch to load. - augment: whether to apply the data augmentation scheme mentioned in the paper. Only applied on the train split. - random_seed: fix seed for reproducibility. - valid_size: percentage split of the training set used for the validation set. Should be a float in the range [0, 1]. In the paper, this number is set to 0.1. - shuffle: whether to shuffle the train/validation indices. - show_sample: plot 9x9 sample grid of the dataset. - num_workers: number of subprocesses to use when loading the dataset. - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to True if using GPU. Returns ------- - train_loader: training set iterator. - valid_loader: validation set iterator. """ error_msg = "[!] valid_size should be in the range [0, 1]." assert ((valid_size >= 0) and (valid_size <= 1)), error_msg normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # define transforms valid_transform = transforms.Compose([ transforms.ToTensor(), normalize ]) if augment: train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize ]) else: train_transform = transforms.Compose([ transforms.ToTensor(), normalize ]) # load the dataset train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform) valid_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=valid_transform) num_train = len(train_dataset) indices = list(range(num_train)) split = int(np.floor(valid_size * num_train)) if shuffle == True: np.random.seed(random_seed) np.random.shuffle(indices) train_idx, valid_idx = indices[split:], indices[:split] train_sampler = SubsetRandomSampler(train_idx) valid_sampler = SubsetRandomSampler(valid_idx) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory) valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=pin_memory) # visualize some images if show_sample: sample_loader = torch.utils.data.DataLoader(train_dataset, batch_size=9, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory) data_iter = iter(sample_loader) images, labels = data_iter.next() X = images.numpy() X = np.transpose(X, [0, 2, 3, 1]) plot_images(X, labels) return (train_loader, valid_loader) def get_test_loader(data_dir, batch_size, shuffle=True, num_workers=4, pin_memory=False): """ Utility function for loading and returning a multi-process test iterator over the CIFAR-10 dataset. If using CUDA, num_workers should be set to 1 and pin_memory to True. Params ------ - data_dir: path directory to the dataset. - batch_size: how many samples per batch to load. - shuffle: whether to shuffle the dataset after every epoch. - num_workers: number of subprocessed to use when loading the dataset. - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to True if using GPU. Returns ------- - data_loader: test set iterator. """ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # define transform transform = transforms.Compose([ transforms.ToTensor(), normalize ]) dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform) data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory) return data_loader