Skip to content

Instantly share code, notes, and snippets.

@manncodes
Created September 12, 2022 17:52
Show Gist options
  • Save manncodes/25e7ec92585ef6449ce7661f170d9b4c to your computer and use it in GitHub Desktop.
Save manncodes/25e7ec92585ef6449ce7661f170d9b4c to your computer and use it in GitHub Desktop.

Revisions

  1. manncodes created this gist Sep 12, 2022.
    48 changes: 48 additions & 0 deletions fast-tensor-dataloader.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,48 @@
    import torch

    class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """
    def __init__(self, *tensors, batch_size=32, shuffle=False):
    """
    Initialize a FastTensorDataLoader.
    :param *tensors: tensors to store. Must have the same length @ dim 0.
    :param batch_size: batch size to load.
    :param shuffle: if True, shuffle the data *in-place* whenever an
    iterator is created out of this object.
    :returns: A FastTensorDataLoader.
    """
    assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
    self.tensors = tensors

    self.dataset_len = self.tensors[0].shape[0]
    self.batch_size = batch_size
    self.shuffle = shuffle

    # Calculate # batches
    n_batches, remainder = divmod(self.dataset_len, self.batch_size)
    if remainder > 0:
    n_batches += 1
    self.n_batches = n_batches
    def __iter__(self):
    if self.shuffle:
    r = torch.randperm(self.dataset_len)
    self.tensors = [t[r] for t in self.tensors]
    self.i = 0
    return self

    def __next__(self):
    if self.i >= self.dataset_len:
    raise StopIteration
    batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors)
    self.i += self.batch_size
    return batch

    def __len__(self):
    return self.n_batches