import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class MyDataset(Dataset):
    """Synthetic classification dataset of random feature vectors and labels.

    Each sample is a pair ``(features, label)``:
    - ``features`` is a list of ``dim`` floats drawn uniformly from ``[0, 1)``.
    - ``label`` is an ``int`` drawn uniformly from ``[0, num_classes)``.

    Args:
        num_samples: Number of samples to generate (default 1000).
        dim: Length of each feature vector (default 3).
        num_classes: Exclusive upper bound for the integer labels (default 2).
        seed: Optional seed for a private ``np.random.RandomState``. When
            ``None``, the global NumPy RNG is used, so ``np.random.seed``
            still controls the generated data.
    """

    def __init__(self, num_samples=1000, dim=3, num_classes=2, seed=None):
        rng = np.random if seed is None else np.random.RandomState(seed)
        # Draw features before labels so the draw order matches the original
        # code (keeps results identical under a global np.random.seed).
        x = rng.rand(num_samples, dim)
        y = rng.randint(low=0, high=num_classes, size=(num_samples,))
        # Store plain Python types: ``tolist`` turns each row into a list of
        # floats and each label into an ``int`` (not ``np.int64``).
        self.x = x.tolist()
        self.y = y.tolist()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


def collate_fn(batch):
    """Stack a list of ``(features, label)`` pairs into batch tensors.

    Args:
        batch: Sequence of ``(features, label)`` pairs as produced by
            ``MyDataset.__getitem__``.

    Returns:
        A ``(data, labels)`` tuple:
        - ``data`` is a float32 tensor of shape ``(len(batch), dim)``.
        - ``labels`` is an int64 tensor of shape ``(len(batch),)``.
    """
    data_list = [features for features, _ in batch]
    label_list = [label for _, label in batch]
    # torch.tensor with explicit dtypes replaces the legacy torch.Tensor /
    # torch.LongTensor constructors, which yield the same dtypes.
    return (
        torch.tensor(data_list, dtype=torch.float32),
        torch.tensor(label_list, dtype=torch.long),
    )


def main():
    """Build the dataset, print one sample, and show the first collated batch."""
    dataset = MyDataset()
    print(len(dataset))
    print(dataset[-1])
    dataloader = DataLoader(
        dataset, batch_size=3, shuffle=False, collate_fn=collate_fn
    )
    data, label = next(iter(dataloader))
    print(type(data))
    print(data)
    print(type(label))
    print(label)


if __name__ == "__main__":
    main()