Skip to content

Instantly share code, notes, and snippets.

@patel-zeel
Created January 26, 2024 06:42
Show Gist options
  • Select an option

  • Save patel-zeel/9c84d529fcf093bad02ca74c85c3a97f to your computer and use it in GitHub Desktop.

Select an option

Save patel-zeel/9c84d529fcf093bad02ca74c85c3a97f to your computer and use it in GitHub Desktop.
Jigsaw SSL
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset\n",
"\n",
"from astra.torch.models import EfficientNet, MLPClassifier, SIRENRegressor\n",
"from astra.torch.data import load_cifar_10\n",
"from astra.torch.utils import train_fn\n",
"\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
},
{
"data": {
"text/plain": [
"\n",
"CIFAR-10 Dataset\n",
"length of dataset: 60000\n",
"shape of images: torch.Size([3, 32, 32])\n",
"len of classes: 10\n",
"classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
"dtype of images: torch.float32\n",
"dtype of labels: torch.int64\n",
"range of image values: min=0.0, max=1.0\n",
" "
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load CIFAR-10 through astra; the repr below prints a dataset summary.\n",
"data = load_cifar_10()\n",
"data  # display the summary"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([60000, 3, 32, 32]), torch.Size([60000]))"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Images and labels should share the leading (sample) dimension.\n",
"(data.data.shape, data.targets.shape)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n",
"torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n",
"torch.Size([40000, 3, 32, 32]) torch.Size([40000])\n"
]
}
],
"source": [
"# Randomly partition every sample index into three disjoint subsets:\n",
"#   test  -> first 40000 shuffled indices\n",
"#   ssl   -> next 10000 (images for the jigsaw pretext task below)\n",
"#   train -> the remainder (labeled data for the downstream experiment)\n",
"# A seeded generator makes the split reproducible across kernel restarts.\n",
"N_TEST, N_SSL = 40000, 10000\n",
"split_gen = torch.Generator().manual_seed(0)\n",
"all_idx = torch.randperm(len(data), generator=split_gen)\n",
"test_idx = all_idx[:N_TEST]\n",
"ssl_idx = all_idx[N_TEST:N_TEST + N_SSL]\n",
"train_idx = all_idx[N_TEST + N_SSL:]\n",
"\n",
"ssl_data_x = data.data[ssl_idx]\n",
"ssl_data_y = data.targets[ssl_idx]\n",
"train_data_x = data.data[train_idx]\n",
"train_data_y = data.targets[train_idx]\n",
"test_data_x = data.data[test_idx]\n",
"test_data_y = data.targets[test_idx]\n",
"\n",
"print(ssl_data_x.shape, ssl_data_y.shape)\n",
"print(train_data_x.shape, train_data_y.shape)\n",
"print(test_data_x.shape, test_data_y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 64/64 [00:06<00:00, 10.53it/s]\n"
]
},
{
"data": {
"text/plain": [
"(64, 9)"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tqdm import trange\n",
"import numpy as np\n",
"import itertools\n",
"from scipy.spatial.distance import cdist\n",
"\n",
"# Greedily pick N_PERMS orderings of the 9 jigsaw tiles that are far apart:\n",
"# start from a random ordering, then repeatedly add the candidate with the\n",
"# largest mean Hamming distance to the orderings chosen so far.\n",
"N_TILES = 9\n",
"N_PERMS = 64\n",
"rng = np.random.default_rng(0)  # seeded so the permutation set is reproducible\n",
"\n",
"# All 9! = 362880 candidate orderings, one per row.\n",
"P_hat = np.array(list(itertools.permutations(range(N_TILES))))\n",
"\n",
"j = rng.integers(len(P_hat))\n",
"P = P_hat[j][None, :]\n",
"P_hat = np.delete(P_hat, j, axis=0)\n",
"# dist_sum[k]: summed Hamming distance from candidate k to every selected row of P.\n",
"# Updated incrementally instead of recomputing cdist against all of P each step;\n",
"# argmax of the sum equals argmax of the mean because the divisor is shared.\n",
"dist_sum = cdist(P, P_hat, metric='hamming').sum(axis=0)\n",
"\n",
"for _ in trange(N_PERMS - 1):\n",
"    j = dist_sum.argmax()\n",
"    P = np.concatenate([P, P_hat[j][None, :]], axis=0)\n",
"    P_hat = np.delete(P_hat, j, axis=0)\n",
"    dist_sum = np.delete(dist_sum, j) + cdist(P[-1:], P_hat, metric='hamming')[0]\n",
"\n",
"P.shape"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4, 8, 6, 7, 5, 1, 0, 3, 2],\n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8],\n",
" [1, 0, 3, 2, 6, 4, 5, 8, 7],\n",
" [2, 3, 0, 1, 7, 6, 8, 4, 5],\n",
" [3, 2, 1, 0, 8, 7, 4, 5, 6]])"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First five selected tile orderings (each row is one permutation class).\n",
"P[0:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1% experiment"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1000, 3, 32, 32]) torch.Size([1000])\n"
]
}
],
"source": [
"# Labeled subset for this experiment: the first 1000 (already shuffled) training samples.\n",
"train_x, train_y = train_data_x[:1000], train_data_y[:1000]\n",
"print(train_x.shape, train_y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10000, 3, 32, 32]) (64, 9)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9431228fcde846b7b360acd96f1fe9a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30 Loss 4.0134\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8fe8bf35a44f424ea5a62c6c74df5600",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/30 Loss 3.1067\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "82777580f1964212b84aae37d7f58196",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/30 Loss 2.5024\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59f31212c9a34c5aadce02c906ea8828",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/30 Loss 2.1696\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6dce75aa5e064011b549717d6c51f15e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/30 Loss 1.9577\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "92636483e4bd41da995184febe0763b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/30 Loss 1.7176\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1d29a37008642a7a96c148680fd6f83",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/30 Loss 1.6413\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9a8f8ae69d40491d9721b25d69323f6b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/30 Loss 1.5056\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1537a032609347b3a55a3b8118e9f17d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/30 Loss 1.3762\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b225de4e79e4d1fbedccb3283d6dc76",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/30 Loss 1.2884\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "310e973e1cc747389d3a3488a501a116",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11/30 Loss 1.2311\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b470927bc6f4a57afde45d94e253f3c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12/30 Loss 1.1558\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d4e4b26da2704c909bc01a8902968d8b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13/30 Loss 1.0807\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ef7e87d6c794dbe864d60a785699117",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14/30 Loss 1.0372\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aab8add69634455f928d4bdba3cc5e2f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15/30 Loss 1.0466\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e13713aec5a34867a3aa01d56811d9ab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 16/30 Loss 1.0451\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d1669ee2c49f48cda15337ee8b8e1e6d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 17/30 Loss 0.9513\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "791e5c2dfbf74b60b709239a061189f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 18/30 Loss 0.9605\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2ecc6d254e54a55ab93f424509a3a25",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19/30 Loss 0.9428\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df523faea1764a078794901444781c21",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 20/30 Loss 0.9093\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5ebcfcea8d984275827ffb0fe4a52f3c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 21/30 Loss 0.8614\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6b4fef64fdb4a4cbe35eb926cd7ec66",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 22/30 Loss 0.8274\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b3f9c57fb33140c794ee6d2ddf3e51e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 23/30 Loss 0.9073\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "44e47020d1714a97bff7658e27c8fb13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 24/30 Loss 0.9315\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a444c7b0b5ef4bffaa47765438e34513",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 25/30 Loss 0.7980\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c9686467c234438296bc21841df24063",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 26/30 Loss 0.7909\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bba77f75a77d4a8fa452c6e9a3c6ec35",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 27/30 Loss 0.7929\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c811618a28ab44dea66a6e24ed9fbcb5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 28/30 Loss 0.8714\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "059b7d573d394ec58cafc0f19f04767d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 29/30 Loss 0.8106\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b6d2efdd428b46b4bb1a00ceb649278e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 30/30 Loss 0.6963\n"
]
}
],
"source": [
"class OurDataSet(Dataset):\n",
"    \"\"\"Jigsaw-puzzle pretext dataset.\n",
"\n",
"    Each image is cut into a grid x grid layout of equal-sized tiles, numbered\n",
"    in row-major order. ``__getitem__`` samples one ordering from ``P``\n",
"    uniformly at random, returns the tiles stacked in that order, and uses the\n",
"    ordering's row index in ``P`` as the classification target.\n",
"\n",
"    Args:\n",
"        x: image tensor whose last two dims are (H, W); presumably\n",
"            (N, C, H, W) - TODO confirm.\n",
"        P: array of shape (n_perms, grid * grid); each row is a tile ordering.\n",
"        grid: tiles per side (default 3, i.e. 9 tiles).\n",
"\n",
"    Raises:\n",
"        ValueError: if the rows of P do not have grid * grid entries.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, x, P, grid=3):\n",
"        self.x = x\n",
"        self.P = P\n",
"        print(self.x.shape, self.P.shape)\n",
"        if self.P.shape[1] != grid * grid:\n",
"            raise ValueError(f'P rows must have {grid * grid} entries, got {self.P.shape[1]}')\n",
"\n",
"        # Integer division drops leftover pixels at the bottom/right edge\n",
"        # (32x32 inputs -> 10x10 tiles at offsets 0, 10, 20; 2 px unused).\n",
"        th, tw = x.shape[-2] // grid, x.shape[-1] // grid\n",
"        # patches[k] is a view of tile k (row-major) for every image in x.\n",
"        self.patches = [\n",
"            self.x[..., i:i + th, j:j + tw]\n",
"            for i in range(0, grid * th, th)\n",
"            for j in range(0, grid * tw, tw)\n",
"        ]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.x)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # Sample from every permutation in P (previously hard-coded to 64).\n",
"        randidx = torch.randint(0, len(self.P), (1,)).squeeze()\n",
"        perm = self.P[randidx]\n",
"        patch = torch.stack([self.patches[perm_id][idx] for perm_id in perm], dim=0)\n",
"        return patch, randidx\n",
"\n",
"\n",
"model = EfficientNet().cuda()\n",
"# Tile features are concatenated before classification; 1280 is assumed to be\n",
"# the per-tile EfficientNet feature size (TODO confirm against astra).\n",
"aggregator = MLPClassifier(1280 * P.shape[1], [1024, 256], len(P)).cuda()\n",
"dataset = OurDataSet(ssl_data_x, P)\n",
"loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)\n",
"\n",
"epochs = 30\n",
"\n",
"model.train()\n",
"aggregator.train()\n",
"optimizer = torch.optim.Adam(list(model.parameters()) + list(aggregator.parameters()), lr=1e-3)\n",
"for epoch in range(epochs):\n",
"    pbar = tqdm(loader)\n",
"    losses = []\n",
"    for x, y in pbar:\n",
"        # Encode each tile position with the shared backbone, concatenate the\n",
"        # features, and classify which permutation was applied.\n",
"        nine_outs = torch.cat([model(x[:, i].cuda()) for i in range(x.shape[1])], dim=-1)\n",
"        logits = aggregator(nine_outs)\n",
"        loss = F.cross_entropy(logits, y.cuda())\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"\n",
"        loss_val = loss.item()  # one device->host sync per step\n",
"        pbar.set_description(f'Epoch {epoch+1}/{epochs} Loss {loss_val:.4f}')\n",
"        losses.append(loss_val)\n",
"    print(f'Epoch {epoch+1}/{epochs} Loss {np.mean(losses):.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a6a1959c7934d1f8ec493426cd8d6c1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/40 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8053\n"
]
}
],
"source": [
"# Jigsaw pretext accuracy. NOTE: `loader` iterates the same SSL images the model\n",
"# was trained on (each with a freshly sampled permutation), so this is a\n",
"# training-set accuracy, not a held-out estimate.\n",
"model.eval()\n",
"aggregator.eval()\n",
"pred_list = []\n",
"gt_list = []\n",
"with torch.no_grad():  # inference only: don't build autograd graphs\n",
"    for x, y in tqdm(loader):\n",
"        nine_outs = torch.cat([model(x[:, i].cuda()) for i in range(x.shape[1])], dim=-1)\n",
"        preds = aggregator(nine_outs).argmax(dim=1)\n",
"        pred_list.append(preds.numpy(force=True))\n",
"        gt_list.append(y.numpy(force=True))\n",
"\n",
"pred_list = np.concatenate(pred_list, axis=0)\n",
"gt_list = np.concatenate(gt_list, axis=0)\n",
"print((pred_list == gt_list).mean())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch_gpu_py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment