In [114]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset

from astra.torch.models import EfficientNet, MLPClassifier, SIRENRegressor
from astra.torch.data import load_cifar_10
from astra.torch.utils import train_fn

from tqdm.notebook import tqdm

In [100]:
data = load_cifar_10()
data

Files already downloaded and verified
Files already downloaded and verified



CIFAR-10 Dataset
length of dataset: 60000
shape of images: torch.Size([3, 32, 32])
len of classes: 10
classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dtype of images: torch.float32
dtype of labels: torch.int64
range of image values: min=0.0, max=1.0
            

In [101]:
data.data.shape, data.targets.shape

(torch.Size([60000, 3, 32, 32]), torch.Size([60000]))

In [102]:
all_idx = torch.randperm(len(data))
test_idx = all_idx[:40000]
ssl_idx = all_idx[40000:50000]
train_idx = all_idx[50000:]

ssl_data_x = data.data[ssl_idx]
ssl_data_y = data.targets[ssl_idx]
train_data_x = data.data[train_idx]
train_data_y = data.targets[train_idx]
test_data_x = data.data[test_idx]
test_data_y = data.targets[test_idx]

print(ssl_data_x.shape, ssl_data_y.shape)
print(train_data_x.shape, train_data_y.shape)
print(test_data_x.shape, test_data_y.shape)

torch.Size([10000, 3, 32, 32]) torch.Size([10000])
torch.Size([10000, 3, 32, 32]) torch.Size([10000])
torch.Size([40000, 3, 32, 32]) torch.Size([40000])


In [103]:
from tqdm import trange
import numpy as np
import itertools
from scipy.spatial.distance import cdist

P_hat = np.array(list(itertools.permutations(list(range(9)), 9)))
n = P_hat.shape[0]

for i in trange(64):
    if i==0:
        j = np.random.randint(n)
        P = np.array(P_hat[j]).reshape([1,-1])
    else:
        P = np.concatenate([P,P_hat[j].reshape([1,-1])],axis=0)

    P_hat = np.delete(P_hat,j,axis=0)
    D = cdist(P,P_hat, metric='hamming').mean(axis=0).flatten()
    
    j = D.argmax()

P.shape

100%|██████████| 64/64 [00:06<00:00, 10.53it/s]


(64, 9)

In [104]:
P[:5]

array([[4, 8, 6, 7, 5, 1, 0, 3, 2],
       [0, 1, 2, 3, 4, 5, 6, 7, 8],
       [1, 0, 3, 2, 6, 4, 5, 8, 7],
       [2, 3, 0, 1, 7, 6, 8, 4, 5],
       [3, 2, 1, 0, 8, 7, 4, 5, 6]])

## 1% experiment

In [105]:
train_x = train_data_x[:1000]
train_y = train_data_y[:1000]
print(train_x.shape, train_y.shape)

torch.Size([1000, 3, 32, 32]) torch.Size([1000])


In [126]:
class OurDataSet(Dataset):
    def __init__(self, x, P):
        self.x = x
        self.P = P
        print(self.x.shape, self.P.shape)
        
        # split into 9 parts
        self.patches = []
        for i in [0, 10, 20]:
            for j in [0, 10, 20]:
                self.patches.append(self.x[...,i:i+10,j:j+10])
        # print(len(self.patches))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        randidx = torch.randint(0, 64, (1,)).squeeze()
        perm = self.P[randidx]
        patches = []
        for perm_id in perm:
            patches.append(self.patches[perm_id][idx])
            # print(perm_id, patches[-1].shape)
        patch = torch.stack(patches, dim=0)
        return patch, randidx

model = EfficientNet().cuda()
aggregator = MLPClassifier(1280*9, [1024, 256], 64).cuda()
dataset = OurDataSet(ssl_data_x, P)
loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)

epochs = 30

model.train()
aggregator.train()
optimizer = torch.optim.Adam(list(model.parameters())+list(aggregator.parameters()), lr=1e-3)
for epoch in range(epochs):
    pbar = tqdm(loader)
    losses = []
    for x, y in pbar:
        nine_outs = []
        for i in range(9):
            nine_outs.append(model(x[:,i,...].cuda()))
        nine_outs = torch.cat(nine_outs, dim=-1)
        nine_outs = aggregator(nine_outs)
        loss = F.cross_entropy(nine_outs, y.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        pbar.set_description(f'Epoch {epoch+1}/{epochs} Loss {loss.item():.4f}')
        losses.append(loss.item())
    print(f'Epoch {epoch+1}/{epochs} Loss {np.mean(losses):.4f}')

torch.Size([10000, 3, 32, 32]) (64, 9)


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/30 Loss 4.0134


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 2/30 Loss 3.1067


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 3/30 Loss 2.5024


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 4/30 Loss 2.1696


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 5/30 Loss 1.9577


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 6/30 Loss 1.7176


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 7/30 Loss 1.6413


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 8/30 Loss 1.5056


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 9/30 Loss 1.3762


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 10/30 Loss 1.2884


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 11/30 Loss 1.2311


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 12/30 Loss 1.1558


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 13/30 Loss 1.0807


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 14/30 Loss 1.0372


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 15/30 Loss 1.0466


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 16/30 Loss 1.0451


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 17/30 Loss 0.9513


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 18/30 Loss 0.9605


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 19/30 Loss 0.9428


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 20/30 Loss 0.9093


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 21/30 Loss 0.8614


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 22/30 Loss 0.8274


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 23/30 Loss 0.9073


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 24/30 Loss 0.9315


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 25/30 Loss 0.7980


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 26/30 Loss 0.7909


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 27/30 Loss 0.7929


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 28/30 Loss 0.8714


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 29/30 Loss 0.8106


  0%|          | 0/40 [00:00<?, ?it/s]

Epoch 30/30 Loss 0.6963


In [129]:
model.eval()
aggregator.eval()
pred_list = []
gt_list = []
for x, y in tqdm(loader):
    nine_outs = []
    for i in range(9):
        nine_outs.append(model(x[:,i,...].cuda()))
    nine_outs = torch.cat(nine_outs, dim=-1)
    nine_outs = aggregator(nine_outs)
    preds = nine_outs.argmax(dim=1)
    pred_list.append(preds.numpy(force=True))
    gt_list.append(y.numpy(force=True))
    
pred_list = np.concatenate(pred_list, axis=0)
gt_list = np.concatenate(gt_list, axis=0)
print((pred_list==gt_list).mean())

  0%|          | 0/40 [00:00<?, ?it/s]

0.8053
