Skip to content

Instantly share code, notes, and snippets.

@pdesgarets
Last active February 16, 2022 08:03
Show Gist options
  • Select an option

  • Save pdesgarets/e045bffc02f6563b400c28fbe2e88f04 to your computer and use it in GitHub Desktop.

Select an option

Save pdesgarets/e045bffc02f6563b400c28fbe2e88f04 to your computer and use it in GitHub Desktop.

Revisions

  1. pdesgarets revised this gist Feb 16, 2022. No changes.
  2. pdesgarets created this gist Feb 15, 2022.
    321 changes: 321 additions & 0 deletions tutorial_pytorch_gpu.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,321 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "electric-prime",
    "metadata": {},
    "outputs": [],
    "source": [
    "import sys\n",
    "!{sys.executable} -m pip install -U torch numpy matplotlib torchvision ipywidgets jupyter widgetsnbextension"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "compressed-procedure",
    "metadata": {},
    "outputs": [],
    "source": [
    "from __future__ import print_function, division\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim import lr_scheduler\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "\n",
    "cudnn.benchmark = True\n",
    "plt.ion() # interactive mode"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "periodic-barrel",
    "metadata": {},
    "outputs": [],
    "source": [
    "import urllib.request\n",
    "import zipfile\n",
    "filename, _ = urllib.request.urlretrieve(\"https://download.pytorch.org/tutorial/hymenoptera_data.zip\", \"data.zip\")\n",
    "zipfile.ZipFile(filename).extractall()"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "opponent-trust",
    "metadata": {},
    "outputs": [],
    "source": [
    "data_transforms = {\n",
    " 'train': transforms.Compose([\n",
    " transforms.RandomResizedCrop(224),\n",
    " transforms.RandomHorizontalFlip(),\n",
    " transforms.ToTensor(),\n",
    " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    " ]),\n",
    " 'val': transforms.Compose([\n",
    " transforms.Resize(256),\n",
    " transforms.CenterCrop(224),\n",
    " transforms.ToTensor(),\n",
    " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    " ]),\n",
    "}\n",
    "\n",
    "data_dir = 'hymenoptera_data'\n",
    "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n",
    " data_transforms[x])\n",
    " for x in ['train', 'val']}\n",
    "dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,\n",
    " shuffle=True, num_workers=4)\n",
    " for x in ['train', 'val']}\n",
    "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
    "class_names = image_datasets['train'].classes\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "threaded-johnson",
    "metadata": {},
    "outputs": [],
    "source": [
    "def imshow(inp, title=None):\n",
    " \"\"\"Imshow for Tensor.\"\"\"\n",
    " inp = inp.numpy().transpose((1, 2, 0))\n",
    " mean = np.array([0.485, 0.456, 0.406])\n",
    " std = np.array([0.229, 0.224, 0.225])\n",
    " inp = std * inp + mean\n",
    " inp = np.clip(inp, 0, 1)\n",
    " plt.imshow(inp)\n",
    " if title is not None:\n",
    " plt.title(title)\n",
    " plt.pause(0.001) # pause a bit so that plots are updated\n",
    "\n",
    "\n",
    "# Get a batch of training data\n",
    "inputs, classes = next(iter(dataloaders['train']))\n",
    "\n",
    "# Make a grid from batch\n",
    "out = torchvision.utils.make_grid(inputs)\n",
    "\n",
    "imshow(out, title=[class_names[x] for x in classes])\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "speaking-rescue",
    "metadata": {},
    "outputs": [],
    "source": [
    "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
    " since = time.time()\n",
    "\n",
    " best_model_wts = copy.deepcopy(model.state_dict())\n",
    " best_acc = 0.0\n",
    "\n",
    " for epoch in range(num_epochs):\n",
    " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    " print('-' * 10)\n",
    "\n",
    " # Each epoch has a training and validation phase\n",
    " for phase in ['train', 'val']:\n",
    " if phase == 'train':\n",
    " model.train() # Set model to training mode\n",
    " else:\n",
    " model.eval() # Set model to evaluate mode\n",
    "\n",
    " running_loss = 0.0\n",
    " running_corrects = 0\n",
    "\n",
    " # Iterate over data.\n",
    " for inputs, labels in dataloaders[phase]:\n",
    " inputs = inputs.to(device)\n",
    " labels = labels.to(device)\n",
    "\n",
    " # zero the parameter gradients\n",
    " optimizer.zero_grad()\n",
    "\n",
    " # forward\n",
    " # track history if only in train\n",
    " with torch.set_grad_enabled(phase == 'train'):\n",
    " outputs = model(inputs)\n",
    " _, preds = torch.max(outputs, 1)\n",
    " loss = criterion(outputs, labels)\n",
    "\n",
    " # backward + optimize only if in training phase\n",
    " if phase == 'train':\n",
    " loss.backward()\n",
    " optimizer.step()\n",
    "\n",
    " # statistics\n",
    " running_loss += loss.item() * inputs.size(0)\n",
    " running_corrects += torch.sum(preds == labels.data)\n",
    " if phase == 'train':\n",
    " scheduler.step()\n",
    "\n",
    " epoch_loss = running_loss / dataset_sizes[phase]\n",
    " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
    "\n",
    " print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
    " phase, epoch_loss, epoch_acc))\n",
    "\n",
    " # deep copy the model\n",
    " if phase == 'val' and epoch_acc > best_acc:\n",
    " best_acc = epoch_acc\n",
    " best_model_wts = copy.deepcopy(model.state_dict())\n",
    "\n",
    " print()\n",
    "\n",
    " time_elapsed = time.time() - since\n",
    " print('Training complete in {:.0f}m {:.0f}s'.format(\n",
    " time_elapsed // 60, time_elapsed % 60))\n",
    " print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    " # load best model weights\n",
    " model.load_state_dict(best_model_wts)\n",
    " return model"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "acting-brick",
    "metadata": {},
    "outputs": [],
    "source": [
    "def visualize_model(model, num_images=6):\n",
    " was_training = model.training\n",
    " model.eval()\n",
    " images_so_far = 0\n",
    " fig = plt.figure()\n",
    "\n",
    " with torch.no_grad():\n",
    " for i, (inputs, labels) in enumerate(dataloaders['val']):\n",
    " inputs = inputs.to(device)\n",
    " labels = labels.to(device)\n",
    "\n",
    " outputs = model(inputs)\n",
    " _, preds = torch.max(outputs, 1)\n",
    "\n",
    " for j in range(inputs.size()[0]):\n",
    " images_so_far += 1\n",
    " ax = plt.subplot(num_images//2, 2, images_so_far)\n",
    " ax.axis('off')\n",
    " ax.set_title('predicted: {}'.format(class_names[preds[j]]))\n",
    " imshow(inputs.cpu().data[j])\n",
    "\n",
    " if images_so_far == num_images:\n",
    " model.train(mode=was_training)\n",
    " return\n",
    " model.train(mode=was_training)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "surprising-allowance",
    "metadata": {},
    "outputs": [],
    "source": [
    "!jupyter nbextension enable --py widgetsnbextension"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "phantom-snowboard",
    "metadata": {},
    "outputs": [],
    "source": [
    "model_ft = models.resnet18(pretrained=True)\n",
    "num_ftrs = model_ft.fc.in_features\n",
    "# Here the size of each output sample is set to 2.\n",
    "# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).\n",
    "model_ft.fc = nn.Linear(num_ftrs, 2)\n",
    "\n",
    "model_ft = model_ft.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Observe that all parameters are being optimized\n",
    "optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "# Decay LR by a factor of 0.1 every 7 epochs\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "enabling-manor",
    "metadata": {},
    "outputs": [],
    "source": [
    "model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,\n",
    " num_epochs=25)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "existing-shark",
    "metadata": {},
    "outputs": [],
    "source": [
    "device = torch.device(\"cpu\")\n",
    "model_ft = models.resnet18(pretrained=True)\n",
    "num_ftrs = model_ft.fc.in_features\n",
    "# Here the size of each output sample is set to 2.\n",
    "# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).\n",
    "model_ft.fc = nn.Linear(num_ftrs, 2)\n",
    "\n",
    "model_ft = model_ft.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Observe that all parameters are being optimized\n",
    "optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "# Decay LR by a factor of 0.1 every 7 epochs\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)\n",
    "model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,\n",
    " num_epochs=25)"
    ]
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.8.8"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }