import random

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import TensorDataset


# Adapted from:
# https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset, labels=None):
        self.labels = labels
        self.dataset = dict()
        self.balanced_max = 0
        # Save all the indices for all the classes
        for idx in range(0, len(dataset)):
            label = self._get_label(dataset, idx)
            if label not in self.dataset:
                self.dataset[label] = list()
            self.dataset[label].append(idx)
            self.balanced_max = len(self.dataset[label]) \
                if len(self.dataset[label]) > self.balanced_max else self.balanced_max

        # Oversample the classes with fewer elements than the max
        for label in self.dataset:
            while len(self.dataset[label]) < self.balanced_max:
                self.dataset[label].append(random.choice(self.dataset[label]))

        self.keys = list(self.dataset.keys())
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        # Round-robin over the classes: yield one index per class in turn, so
        # every consecutive block of len(self.keys) samples is class-balanced.
        while self.indices[self.currentkey] < self.balanced_max - 1:
            self.indices[self.currentkey] += 1
            yield self.dataset[self.keys[self.currentkey]][self.indices[self.currentkey]]
            self.currentkey = (self.currentkey + 1) % len(self.keys)
        self.indices = [-1] * len(self.keys)

    def _get_label(self, dataset, idx):
        if self.labels is not None:
            return self.labels[idx].item()
        raise Exception("You should pass the tensor of labels to the constructor as second argument")

    def __len__(self):
        return self.balanced_max * len(self.keys)


# Create an unbalanced dataset: 98 samples of class 1 and 2 samples of class 0
X = torch.Tensor(np.random.rand(100, 2))
y = torch.Tensor(np.concatenate((np.ones(98), np.zeros(2))))

# Use the sampler
train_loader = torch.utils.data.DataLoader(
    TensorDataset(X, y),
    sampler=BalancedBatchSampler(X, y),
    batch_size=20)

# Test: each batch alternates between the two classes
for data, labels in train_loader:
    print(labels)
    # tensor([1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.])
    break
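
# Optional sanity check (not part of the original snippet): count the labels in
# each batch with collections.Counter to confirm the split is even. With the
# unbalanced data above, every full batch of 20 should contain 10 of each class.
from collections import Counter

for data, labels in train_loader:
    print(Counter(labels.tolist()))  # e.g. Counter({1.0: 10, 0.0: 10})
    break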