{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "from torch.utils.data import random_split, TensorDataset, DataLoader, Dataset\n",
    "\n",
    "from astra.torch.models import EfficientNet, MLPClassifier, SIRENRegressor\n",
    "from astra.torch.data import load_cifar_10\n",
    "from astra.torch.utils import train_fn\n",
    "\n",
    "from tqdm import trange\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration. SEED drives the data split and the permutation set below,\n",
    "# and the global seeding covers any later code that draws from torch's or\n",
    "# numpy's global RNGs, so Restart & Run All reproduces the same results.\n",
    "SEED = 42\n",
    "N_PERMUTATIONS = 64  # number of permutations of range(9) to select\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\n",
       "CIFAR-10 Dataset\n",
       "length of dataset: 60000\n",
       "shape of images: torch.Size([3, 32, 32])\n",
       "len of classes: 10\n",
       "classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
       "dtype of images: torch.float32\n",
       "dtype of labels: torch.int64\n",
       "range of image values: min=0.0, max=1.0\n",
       " "
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = load_cifar_10()\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([60000, 3, 32, 32]), torch.Size([60000]))"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.data.shape, data.targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n",
      "torch.Size([10000, 3, 32, 32]) torch.Size([10000])\n",
      "torch.Size([40000, 3, 32, 32]) torch.Size([40000])\n"
     ]
    }
   ],
   "source": [
    "# Shuffle all 60k CIFAR-10 images with a locally seeded generator (so re-running\n",
    "# this cell alone yields the same split) and carve out three disjoint subsets:\n",
    "# 40k test, 10k ssl and the remaining 10k train pool.\n",
    "all_idx = torch.randperm(len(data), generator=torch.Generator().manual_seed(SEED))\n",
    "test_idx = all_idx[:40000]\n",
    "ssl_idx = all_idx[40000:50000]\n",
    "train_idx = all_idx[50000:]\n",
    "\n",
    "ssl_data_x = data.data[ssl_idx]\n",
    "ssl_data_y = data.targets[ssl_idx]\n",
    "train_data_x = data.data[train_idx]\n",
    "train_data_y = data.targets[train_idx]\n",
    "test_data_x = data.data[test_idx]\n",
    "test_data_y = data.targets[test_idx]\n",
    "\n",
    "print(ssl_data_x.shape, ssl_data_y.shape)\n",
    "print(train_data_x.shape, train_data_y.shape)\n",
    "print(test_data_x.shape, test_data_y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 64/64 [00:06<00:00, 10.53it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(64, 9)"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Greedily select N_PERMUTATIONS permutations of range(9) that are far apart:\n",
    "# start from a random one, then repeatedly add the remaining candidate whose\n",
    "# mean (normalised) Hamming distance to the already-selected rows is largest.\n",
    "rng = np.random.default_rng(SEED)  # local RNG keeps this cell idempotent\n",
    "\n",
    "P_hat = np.array(list(itertools.permutations(range(9))))  # all 9! candidates\n",
    "n = P_hat.shape[0]\n",
    "\n",
    "for i in trange(N_PERMUTATIONS):\n",
    "    if i == 0:\n",
    "        j = rng.integers(n)\n",
    "        P = P_hat[j].reshape([1, -1]).copy()\n",
    "    else:\n",
    "        P = np.concatenate([P, P_hat[j].reshape([1, -1])], axis=0)\n",
    "\n",
    "    # Remove the chosen row so it cannot be selected again.\n",
    "    P_hat = np.delete(P_hat, j, axis=0)\n",
    "    D = cdist(P, P_hat, metric='hamming').mean(axis=0)\n",
    "\n",
    "    j = D.argmax()\n",
    "\n",
    "P.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[4, 8, 6, 7, 5, 1, 0, 3, 2],\n",
       "       [0, 1, 2, 3, 4, 5, 6, 7, 8],\n",
       "       [1, 0, 3, 2, 6, 4, 5, 8, 7],\n",
       "       [2, 3, 0, 1, 7, 6, 8, 4, 5],\n",
       "       [3, 2, 1, 0, 8, 7, 4, 5, 6]])"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "P[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1% experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1000, 3, 32, 32]) torch.Size([1000])\n"
     ]
    }
   ],
   "source": [
    "train_x = 
train_data_x[:1000]\n", "train_y = train_data_y[:1000]\n", "print(train_x.shape, train_y.shape)" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10000, 3, 32, 32]) (64, 9)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9431228fcde846b7b360acd96f1fe9a2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/40 [00:00