Created January 26, 2024 06:42
-
-
Save patel-zeel/9c84d529fcf093bad02ca74c85c3a97f to your computer and use it in GitHub Desktop.
Jigsaw SSL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 114, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset\n", | |
| "\n", | |
| "from astra.torch.models import EfficientNet, MLPClassifier, SIRENRegressor\n", | |
| "from astra.torch.data import load_cifar_10\n", | |
| "from astra.torch.utils import train_fn\n", | |
| "\n", | |
| "from tqdm.notebook import tqdm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 100, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Files already downloaded and verified\n", | |
| "Files already downloaded and verified\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "\n", | |
| "CIFAR-10 Dataset\n", | |
| "length of dataset: 60000\n", | |
| "shape of images: torch.Size([3, 32, 32])\n", | |
| "len of classes: 10\n", | |
| "classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", | |
| "dtype of images: torch.float32\n", | |
| "dtype of labels: torch.int64\n", | |
| "range of image values: min=0.0, max=1.0\n", | |
| " " | |
| ] | |
| }, | |
| "execution_count": 100, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "data = load_cifar_10()\n", | |
| "data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 101, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([60000, 3, 32, 32]), torch.Size([60000]))" | |
| ] | |
| }, | |
| "execution_count": 101, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "data.data.shape, data.targets.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 102, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n", | |
| "torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n", | |
| "torch.Size([40000, 3, 32, 32]) torch.Size([40000])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "all_idx = torch.randperm(len(data))\n", | |
| "test_idx = all_idx[:40000]\n", | |
| "ssl_idx = all_idx[40000:50000]\n", | |
| "train_idx = all_idx[50000:]\n", | |
| "\n", | |
| "ssl_data_x = data.data[ssl_idx]\n", | |
| "ssl_data_y = data.targets[ssl_idx]\n", | |
| "train_data_x = data.data[train_idx]\n", | |
| "train_data_y = data.targets[train_idx]\n", | |
| "test_data_x = data.data[test_idx]\n", | |
| "test_data_y = data.targets[test_idx]\n", | |
| "\n", | |
| "print(ssl_data_x.shape, ssl_data_y.shape)\n", | |
| "print(train_data_x.shape, train_data_y.shape)\n", | |
| "print(test_data_x.shape, test_data_y.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 103, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 64/64 [00:06<00:00, 10.53it/s]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(64, 9)" | |
| ] | |
| }, | |
| "execution_count": 103, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from tqdm import trange\n", | |
| "import numpy as np\n", | |
| "import itertools\n", | |
| "from scipy.spatial.distance import cdist\n", | |
| "\n", | |
| "P_hat = np.array(list(itertools.permutations(list(range(9)), 9)))\n", | |
| "n = P_hat.shape[0]\n", | |
| "\n", | |
| "for i in trange(64):\n", | |
| " if i==0:\n", | |
| " j = np.random.randint(n)\n", | |
| " P = np.array(P_hat[j]).reshape([1,-1])\n", | |
| " else:\n", | |
| " P = np.concatenate([P,P_hat[j].reshape([1,-1])],axis=0)\n", | |
| "\n", | |
| " P_hat = np.delete(P_hat,j,axis=0)\n", | |
| " D = cdist(P,P_hat, metric='hamming').mean(axis=0).flatten()\n", | |
| " \n", | |
| " j = D.argmax()\n", | |
| "\n", | |
| "P.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 104, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[4, 8, 6, 7, 5, 1, 0, 3, 2],\n", | |
| " [0, 1, 2, 3, 4, 5, 6, 7, 8],\n", | |
| " [1, 0, 3, 2, 6, 4, 5, 8, 7],\n", | |
| " [2, 3, 0, 1, 7, 6, 8, 4, 5],\n", | |
| " [3, 2, 1, 0, 8, 7, 4, 5, 6]])" | |
| ] | |
| }, | |
| "execution_count": 104, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "P[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1% experiment" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 105, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([1000, 3, 32, 32]) torch.Size([1000])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "train_x = train_data_x[:1000]\n", | |
| "train_y = train_data_y[:1000]\n", | |
| "print(train_x.shape, train_y.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 126, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([10000, 3, 32, 32]) (64, 9)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "9431228fcde846b7b360acd96f1fe9a2", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 1/30 Loss 4.0134\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "8fe8bf35a44f424ea5a62c6c74df5600", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 2/30 Loss 3.1067\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "82777580f1964212b84aae37d7f58196", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 3/30 Loss 2.5024\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "59f31212c9a34c5aadce02c906ea8828", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 4/30 Loss 2.1696\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "6dce75aa5e064011b549717d6c51f15e", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 5/30 Loss 1.9577\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "92636483e4bd41da995184febe0763b2", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 6/30 Loss 1.7176\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "f1d29a37008642a7a96c148680fd6f83", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 7/30 Loss 1.6413\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "9a8f8ae69d40491d9721b25d69323f6b", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 8/30 Loss 1.5056\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "1537a032609347b3a55a3b8118e9f17d", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 9/30 Loss 1.3762\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "2b225de4e79e4d1fbedccb3283d6dc76", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 10/30 Loss 1.2884\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "310e973e1cc747389d3a3488a501a116", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 11/30 Loss 1.2311\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "8b470927bc6f4a57afde45d94e253f3c", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 12/30 Loss 1.1558\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "d4e4b26da2704c909bc01a8902968d8b", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 13/30 Loss 1.0807\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "3ef7e87d6c794dbe864d60a785699117", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 14/30 Loss 1.0372\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "aab8add69634455f928d4bdba3cc5e2f", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 15/30 Loss 1.0466\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "e13713aec5a34867a3aa01d56811d9ab", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 16/30 Loss 1.0451\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "d1669ee2c49f48cda15337ee8b8e1e6d", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 17/30 Loss 0.9513\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "791e5c2dfbf74b60b709239a061189f9", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 18/30 Loss 0.9605\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "d2ecc6d254e54a55ab93f424509a3a25", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 19/30 Loss 0.9428\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "df523faea1764a078794901444781c21", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 20/30 Loss 0.9093\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "5ebcfcea8d984275827ffb0fe4a52f3c", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 21/30 Loss 0.8614\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "f6b4fef64fdb4a4cbe35eb926cd7ec66", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 22/30 Loss 0.8274\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "b3f9c57fb33140c794ee6d2ddf3e51e9", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 23/30 Loss 0.9073\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "44e47020d1714a97bff7658e27c8fb13", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 24/30 Loss 0.9315\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "a444c7b0b5ef4bffaa47765438e34513", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 25/30 Loss 0.7980\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "c9686467c234438296bc21841df24063", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 26/30 Loss 0.7909\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "bba77f75a77d4a8fa452c6e9a3c6ec35", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 27/30 Loss 0.7929\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "c811618a28ab44dea66a6e24ed9fbcb5", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 28/30 Loss 0.8714\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "059b7d573d394ec58cafc0f19f04767d", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 29/30 Loss 0.8106\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "b6d2efdd428b46b4bb1a00ceb649278e", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 30/30 Loss 0.6963\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class OurDataSet(Dataset):\n", | |
| " def __init__(self, x, P):\n", | |
| " self.x = x\n", | |
| " self.P = P\n", | |
| " print(self.x.shape, self.P.shape)\n", | |
| " \n", | |
| " # split into 9 parts\n", | |
| " self.patches = []\n", | |
| " for i in [0, 10, 20]:\n", | |
| " for j in [0, 10, 20]:\n", | |
| " self.patches.append(self.x[...,i:i+10,j:j+10])\n", | |
| " # print(len(self.patches))\n", | |
| "\n", | |
| " def __len__(self):\n", | |
| " return len(self.x)\n", | |
| "\n", | |
| " def __getitem__(self, idx):\n", | |
| " randidx = torch.randint(0, 64, (1,)).squeeze()\n", | |
| " perm = self.P[randidx]\n", | |
| " patches = []\n", | |
| " for perm_id in perm:\n", | |
| " patches.append(self.patches[perm_id][idx])\n", | |
| " # print(perm_id, patches[-1].shape)\n", | |
| " patch = torch.stack(patches, dim=0)\n", | |
| " return patch, randidx\n", | |
| "\n", | |
| "model = EfficientNet().cuda()\n", | |
| "aggregator = MLPClassifier(1280*9, [1024, 256], 64).cuda()\n", | |
| "dataset = OurDataSet(ssl_data_x, P)\n", | |
| "loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)\n", | |
| "\n", | |
| "epochs = 30\n", | |
| "\n", | |
| "model.train()\n", | |
| "aggregator.train()\n", | |
| "optimizer = torch.optim.Adam(list(model.parameters())+list(aggregator.parameters()), lr=1e-3)\n", | |
| "for epoch in range(epochs):\n", | |
| " pbar = tqdm(loader)\n", | |
| " losses = []\n", | |
| " for x, y in pbar:\n", | |
| " nine_outs = []\n", | |
| " for i in range(9):\n", | |
| " nine_outs.append(model(x[:,i,...].cuda()))\n", | |
| " nine_outs = torch.cat(nine_outs, dim=-1)\n", | |
| " nine_outs = aggregator(nine_outs)\n", | |
| " loss = F.cross_entropy(nine_outs, y.cuda())\n", | |
| " loss.backward()\n", | |
| " optimizer.step()\n", | |
| " optimizer.zero_grad()\n", | |
| "\n", | |
| " pbar.set_description(f'Epoch {epoch+1}/{epochs} Loss {loss.item():.4f}')\n", | |
| " losses.append(loss.item())\n", | |
| " print(f'Epoch {epoch+1}/{epochs} Loss {np.mean(losses):.4f}')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 129, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "3a6a1959c7934d1f8ec493426cd8d6c1", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/40 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0.8053\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model.eval()\n", | |
| "aggregator.eval()\n", | |
| "pred_list = []\n", | |
| "gt_list = []\n", | |
| "for x, y in tqdm(loader):\n", | |
| " nine_outs = []\n", | |
| " for i in range(9):\n", | |
| " nine_outs.append(model(x[:,i,...].cuda()))\n", | |
| " nine_outs = torch.cat(nine_outs, dim=-1)\n", | |
| " nine_outs = aggregator(nine_outs)\n", | |
| " preds = nine_outs.argmax(dim=1)\n", | |
| " pred_list.append(preds.numpy(force=True))\n", | |
| " gt_list.append(y.numpy(force=True))\n", | |
| " \n", | |
| "pred_list = np.concatenate(pred_list, axis=0)\n", | |
| "gt_list = np.concatenate(gt_list, axis=0)\n", | |
| "print((pred_list==gt_list).mean())" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "torch_gpu_py311", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.