@ilkarman
Created October 15, 2019 10:11
BalancedSampler.py

import torch
import torch.utils.data
import random
import numpy as np
from torch.utils.data import TensorDataset

# Adapted from:
# https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset, labels=None):
        self.labels = labels
        self.dataset = dict()
        self.balanced_max = 0
        # Save all the indices for all the classes
        for idx in range(0, len(dataset)):
            label = self._get_label(dataset, idx)
            if label not in self.dataset:
                self.dataset[label] = list()
            self.dataset[label].append(idx)
            self.balanced_max = len(self.dataset[label]) \
                if len(self.dataset[label]) > self.balanced_max else self.balanced_max

        # Oversample the classes with fewer elements than the max
        for label in self.dataset:
            while len(self.dataset[label]) < self.balanced_max:
                self.dataset[label].append(random.choice(self.dataset[label]))
        self.keys = list(self.dataset.keys())
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        # Round-robin over the classes: yield one index per class in turn
        # until every (oversampled) class list has been exhausted.
        while self.indices[self.currentkey] < self.balanced_max - 1:
            self.indices[self.currentkey] += 1
            yield self.dataset[self.keys[self.currentkey]][self.indices[self.currentkey]]
            self.currentkey = (self.currentkey + 1) % len(self.keys)
        self.indices = [-1] * len(self.keys)

    def _get_label(self, dataset, idx):
        # This trimmed-down version expects the labels tensor to be passed in explicitly.
        if self.labels is not None:
            return self.labels[idx].item()

    def __len__(self):
        return self.balanced_max * len(self.keys)

# Create an unbalanced data set: 98 positives, 2 negatives
X = torch.Tensor(np.random.rand(100, 2))
y = torch.Tensor(np.concatenate((np.ones(98), np.zeros(2))))

# Use the sampler
train_loader = torch.utils.data.DataLoader(
    TensorDataset(X, y),
    sampler=BalancedBatchSampler(X, y),
    batch_size=20)

# Test: each batch alternates between the two classes
for data, labels in train_loader:
    print(labels)
    # tensor([1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.])
    break
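
As a further sanity check (not part of the original gist), a minimal sketch like the one below tallies the labels in every batch; since the sampler round-robins over the two classes and oversamples the minority class with replacement, each batch of 20 should contain 10 examples of each class.

from collections import Counter

# Count how often each class appears in every batch produced by the sampler.
for batch_idx, (data, labels) in enumerate(train_loader):
    counts = Counter(labels.tolist())
    print(f"batch {batch_idx}: {dict(counts)}")
    # Expected: 10 examples of each class per batch of 20
    # (the final, smaller batch still splits evenly).

Note that the two class-0 examples are duplicated until that class also has 98 entries, so one pass over the loader now yields 196 indices rather than 100, with the same minority samples repeating many times.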