Skip to content

Instantly share code, notes, and snippets.

@AutuanLiu
Forked from kevinzakka/data_loader.py
Created April 5, 2018 06:47
Show Gist options
  • Select an option

  • Save AutuanLiu/cdc506df20e0f7c40bc7b9c4b7b8185c to your computer and use it in GitHub Desktop.

Select an option

Save AutuanLiu/cdc506df20e0f7c40bc7b9c4b7b8185c to your computer and use it in GitHub Desktop.

Revisions

  1. @kevinzakka kevinzakka revised this gist Feb 16, 2018. 1 changed file with 6 additions and 4 deletions.
    10 changes: 6 additions & 4 deletions data_loader.py
    Original file line number Diff line number Diff line change
    @@ -1,7 +1,9 @@
    # This is an example for the CIFAR-10 dataset.
    # There's a function for creating a train and validation iterator.
    # There's also a function for creating a test iterator.
    # Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
    """
    Create train, valid, test iterators for CIFAR-10 [1].
    Easily extended to MNIST, CIFAR-100 and Imagenet.
    [1]: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
    """

    import torch
    import numpy as np
  2. @kevinzakka kevinzakka revised this gist Feb 16, 2018. 2 changed files with 68 additions and 55 deletions.
    102 changes: 57 additions & 45 deletions data_loader.py
    Original file line number Diff line number Diff line change
    @@ -3,7 +3,14 @@
    # There's also a function for creating a test iterator.
    # Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4

    import torch
    import numpy as np

    from utils import plot_images
    from torchvision import datasets
    from torchvision import transforms
    from torch.utils.data.sampler import SubsetRandomSampler


    def get_train_valid_loader(data_dir,
    batch_size,
    @@ -15,8 +22,8 @@ def get_train_valid_loader(data_dir,
    num_workers=4,
    pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    @@ -44,78 +51,82 @@ def get_train_valid_loader(data_dir,
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
    normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    )

    # define transforms
    valid_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    ])
    normalize,
    ])
    if augment:
    train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
    normalize,
    ])
    else:
    train_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    normalize,
    ])

    # load the dataset
    train_dataset = datasets.CIFAR10(root=data_dir, train=True,
    download=True, transform=train_transform)
    train_dataset = datasets.CIFAR10(
    root=data_dir, train=True,
    download=True, transform=train_transform,
    )

    valid_dataset = datasets.CIFAR10(root=data_dir, train=True,
    download=True, transform=valid_transform)
    valid_dataset = datasets.CIFAR10(
    root=data_dir, train=True,
    download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle == True:
    if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory)

    train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
    sample_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=9,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=pin_memory)
    sample_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=9, shuffle=shuffle,
    num_workers=num_workers, pin_memory=pin_memory,
    )
    data_iter = iter(sample_loader)
    images, labels = data_iter.next()
    X = images.numpy()
    X = np.transpose(X, [0, 2, 3, 1])
    X = images.numpy().transpose([0, 2, 3, 1])
    plot_images(X, labels)

    return (train_loader, valid_loader)

    def get_test_loader(data_dir,


    def get_test_loader(data_dir,
    batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    @@ -133,24 +144,25 @@ def get_test_loader(data_dir,
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
    normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    )

    # define transform
    transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    normalize,
    ])

    dataset = datasets.CIFAR10(root=data_dir,
    train=False,
    download=True,
    transform=transform)
    dataset = datasets.CIFAR10(
    root=data_dir, train=False,
    download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=pin_memory)
    data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=shuffle,
    num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
    return data_loader
    21 changes: 11 additions & 10 deletions utils.py
    Original file line number Diff line number Diff line change
    @@ -13,27 +13,28 @@
    'truck'
    ]

    def plot_images(images, cls_true, cls_pred=None):

    assert len(images) == len(cls_true) == 9

    # Create figure with sub-plots.
    def plot_images(images, cls_true, cls_pred=None):
    """
    Adapted from https://github.com/Hvass-Labs/TensorFlow-Tutorials/
    """
    fig, axes = plt.subplots(3, 3)

    for i, ax in enumerate(axes.flat):
    # plot the image
    # plot img
    ax.imshow(images[i, :, :, :], interpolation='spline16')
    # get its equivalent class name

    # show true & predicted classes
    cls_true_name = label_names[cls_true[i]]

    if cls_pred is None:
    xlabel = "{0} ({1})".format(cls_true_name, cls_true[i])
    else:
    cls_pred_name = label_names[cls_pred[i]]
    xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name)

    xlabel = "True: {0}\nPred: {1}".format(
    cls_true_name, cls_pred_name
    )
    ax.set_xlabel(xlabel)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.show()
    plt.show()
  3. @kevinzakka kevinzakka revised this gist Aug 2, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion data_loader.py
    Original file line number Diff line number Diff line change
    @@ -125,7 +125,7 @@ def get_test_loader(data_dir,
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocessed to use when loading the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
    True if using GPU.
  4. @kevinzakka kevinzakka revised this gist Aug 2, 2017. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions data_loader.py
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,7 @@
    # This is an example for the CIFAR-10 dataset.
    # There's a function for creating a train and validation iterator.
    # There's also a function for creating a test iterator.
    # Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4

    from utils import plot_images

  5. @kevinzakka kevinzakka revised this gist Aug 2, 2017. 2 changed files with 41 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion data_loader.py
    Original file line number Diff line number Diff line change
    @@ -2,6 +2,8 @@
    # There's a function for creating a train and validation iterator.
    # There's also a function for creating a test iterator.

    from utils import plot_images

    def get_train_valid_loader(data_dir,
    batch_size,
    augment,
    @@ -27,7 +29,6 @@ def get_train_valid_loader(data_dir,
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
    the validation set. Should be a float in the range [0, 1].
    In the paper, this number is set to 0.1.
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    39 changes: 39 additions & 0 deletions utils.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,39 @@
    import matplotlib.pyplot as plt

    label_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
    ]

    def plot_images(images, cls_true, cls_pred=None):

    assert len(images) == len(cls_true) == 9

    # Create figure with sub-plots.
    fig, axes = plt.subplots(3, 3)

    for i, ax in enumerate(axes.flat):
    # plot the image
    ax.imshow(images[i, :, :, :], interpolation='spline16')
    # get its equivalent class name
    cls_true_name = label_names[cls_true[i]]

    if cls_pred is None:
    xlabel = "{0} ({1})".format(cls_true_name, cls_true[i])
    else:
    cls_pred_name = label_names[cls_pred[i]]
    xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name)

    ax.set_xlabel(xlabel)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.show()
  6. @kevinzakka kevinzakka created this gist Aug 2, 2017.
    154 changes: 154 additions & 0 deletions data_loader.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,154 @@
    # This is an example for the CIFAR-10 dataset.
    # There's a function for creating a train and validation iterator.
    # There's also a function for creating a test iterator.

    def get_train_valid_loader(data_dir,
    batch_size,
    augment,
    random_seed,
    valid_size=0.1,
    shuffle=True,
    show_sample=False,
    num_workers=4,
    pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
    mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
    the validation set. Should be a float in the range [0, 1].
    In the paper, this number is set to 0.1.
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
    True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

    # define transforms
    valid_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    ])
    if augment:
    train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
    ])
    else:
    train_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    ])

    # load the dataset
    train_dataset = datasets.CIFAR10(root=data_dir, train=True,
    download=True, transform=train_transform)

    valid_dataset = datasets.CIFAR10(root=data_dir, train=True,
    download=True, transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle == True:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory)


    # visualize some images
    if show_sample:
    sample_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=9,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=pin_memory)
    data_iter = iter(sample_loader)
    images, labels = data_iter.next()
    X = images.numpy()
    X = np.transpose(X, [0, 2, 3, 1])
    plot_images(X, labels)

    return (train_loader, valid_loader)

    def get_test_loader(data_dir,
    batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocessed to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
    True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

    # define transform
    transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
    ])

    dataset = datasets.CIFAR10(root=data_dir,
    train=False,
    download=True,
    transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=pin_memory)

    return data_loader