@DENGARDEN
Created December 22, 2022 06:11
Revisions

  1. DENGARDEN revised this gist Dec 22, 2022. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion lab01_pytorch.ipynb
    @@ -1,5 +1,15 @@
    {
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
    },
    "source": [
    "<a href=\"https://colab.research.google.com/gist/DENGARDEN/4ab8edd701c5c369a9e5e8e815b0bd2c/lab01_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    @@ -2552,7 +2562,8 @@
    "name": "lab01_pytorch.ipynb",
    "private_outputs": true,
    "provenance": [],
    "toc_visible": true
    "toc_visible": true,
    "include_colab_link": true
    },
    "gpuClass": "standard",
    "kernelspec": {
  2. DENGARDEN created this gist Dec 22, 2022.
    2,578 changes: 2,578 additions & 0 deletions lab01_pytorch.ipynb
    @@ -0,0 +1,2578 @@
    {
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "FlH0lCOttCs5"
    },
    "source": [
    "<img src=\"https://fsdl.me/logo-720-dark-horizontal\">"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "ZUPRHaeetRnT"
    },
    "source": [
    "\n",
    "\n",
    "```\n",
    "# This is formatted as code\n",
    "```\n",
    "\n",
    "# Lab 01: Deep Neural Networks in PyTorch"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "bry3Hr-PcgDs"
    },
    "source": [
    "### What You Will Learn\n",
    "\n",
    "- How to write a basic neural network from scratch in PyTorch\n",
    "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "6c7bFQ20LbLB"
    },
    "source": [
    "At its core, PyTorch is a library for\n",
    "- doing math on arrays\n",
    "- with automatic calculation of gradients\n",
    "- that is easy to accelerate with GPUs and distribute over nodes.\n",
    "\n",
    "Much of the time,\n",
    "we work at a remove from the core features of PyTorch,\n",
    "using abstractions from `torch.nn`\n",
    "or from frameworks on top of PyTorch.\n",
    "\n",
    "This tutorial builds those abstractions up\n",
    "from core PyTorch,\n",
    "showing how to go from basic iterated\n",
    "gradient computation and application\n",
    "to a solid training and validation loop.\n",
    "It is adapted from the PyTorch tutorial\n",
    "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
    "\n",
    "We assume familiarity with the fundamentals of ML and DNNs here,\n",
    "like gradient-based optimization and statistical learning.\n",
    "For refreshing on those, we recommend\n",
    "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
    "or\n",
    "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)"
    ]
    },
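    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "Here's a minimal sketch of those core features,\n",
    "assuming only a working `torch` install:\n",
    "we do some math on an array and get its gradient automatically.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# math on arrays: y = sum(x ** 2), with gradient tracking enabled\n",
    "x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)\n",
    "y = (x ** 2).sum()\n",
    "\n",
    "# automatic calculation of gradients: dy/dx = 2 * x\n",
    "y.backward()\n",
    "print(x.grad)  # tensor([2., 4., 6.])\n",
    "```"
    ]
    },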
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "vs0LXXlCU6Ix"
    },
    "source": [
    "# Setup"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "ZkQiK7lkgeXm"
    },
    "source": [
    "If you're running this notebook on Google Colab,\n",
    "the cell below will run full environment setup.\n",
    "\n",
    "It should take about three minutes to run."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "sVx7C7H0PIZC"
    },
    "outputs": [],
    "source": [
    "lab_idx = 1\n",
    "\n",
    "if \"bootstrap\" not in locals() or bootstrap.run:\n",
    " # path management for Python\n",
    " pythonpath, = !echo $PYTHONPATH\n",
    " if \".\" not in pythonpath.split(\":\"):\n",
    " pythonpath = \".:\" + pythonpath\n",
    " %env PYTHONPATH={pythonpath}\n",
    " !echo $PYTHONPATH\n",
    "\n",
    " # get both Colab and local notebooks into the same state\n",
    " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
    " import bootstrap\n",
    "\n",
    " # change into the lab directory\n",
    " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
    "\n",
    " # allow \"hot-reloading\" of modules\n",
    " %load_ext autoreload\n",
    " %autoreload 2\n",
    " # needed for inline plots in some contexts\n",
    " %matplotlib inline\n",
    "\n",
    " bootstrap.run = False # change to True re-run setup\n",
    " \n",
    "!pwd\n",
    "%ls"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "6wJ8r7BTPB-t"
    },
    "source": [
    "# Getting data and making `Tensor`s"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "MpRyqPPYie-F"
    },
    "source": [
    "Before we can build a model,\n",
    "we need data.\n",
    "\n",
    "The code below uses the Python standard library to download the\n",
    "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
    "from the internet.\n",
    "\n",
    "The data used to train state-of-the-art models these days\n",
    "is generally too large to be stored on the disk of any single machine\n",
    "(to say nothing of the RAM!),\n",
    "so fetching data over a network is a common first step in model training."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "CsokTZTMJ3x6"
    },
    "outputs": [],
    "source": [
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "\n",
    "def download_mnist(path):\n",
    " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
    " filename = \"mnist.pkl.gz\"\n",
    "\n",
    " if not (path / filename).exists():\n",
    " content = requests.get(url + filename).content\n",
    " (path / filename).open(\"wb\").write(content)\n",
    "\n",
    " return path / filename\n",
    "\n",
    "\n",
    "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
    "path = data_path / \"downloaded\" / \"vector-mnist\"\n",
    "path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "datafile = download_mnist(path)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "-S0es1DujOyr"
    },
    "source": [
    "Larger data consumes more resources --\n",
    "when reading, writing, and sending over the network --\n",
    "so the dataset is compressed\n",
    "(`.gz` extension).\n",
    "\n",
    "Each piece of the dataset\n",
    "(training and validation inputs and outputs)\n",
    "is a single Python object\n",
    "(specifically, an array).\n",
    "We can persist Python objects to disk\n",
    "(also known as \"serialization\")\n",
    "and load them back in\n",
    "(also known as \"deserialization\")\n",
    "using the `pickle` library\n",
    "(`.pkl` extension)."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "QZosCF1xJ3x7"
    },
    "outputs": [],
    "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "\n",
    "def read_mnist(path):\n",
    " with gzip.open(path, \"rb\") as f:\n",
    " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
    " return x_train, y_train, x_valid, y_valid\n",
    "\n",
    "x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "KIYUbKgmknDf"
    },
    "source": [
    "PyTorch provides its own array type,\n",
    "the `torch.Tensor`.\n",
    "The cell below converts our arrays into `torch.Tensor`s.\n",
    "\n",
    "Very roughly speaking, a \"tensor\" in ML\n",
    "just means the same thing as an\n",
    "\"array\" elsewhere in computer science.\n",
    "Terminology is different in\n",
    "[physics](https://physics.stackexchange.com/a/270445),\n",
    "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
    "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
    "but here the term \"tensor\" is intended to connote\n",
    "an array that might have more than two dimensions."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "ea5d3Ggfkhea"
    },
    "outputs": [],
    "source": [
    "import torch\n",
    "\n",
    "\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    " torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "D0AMKLxGkmc_"
    },
    "source": [
    "Tensors are defined by their contents:\n",
    "they are big rectangular blocks of numbers."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "yPvh8c_pkl5A"
    },
    "outputs": [],
    "source": [
    "print(x_train, y_train, sep=\"\\n\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "4UOYvwjFqdzu"
    },
    "source": [
    "Accessing the contents of `Tensor`s is called \"indexing\",\n",
    "and uses the same syntax as general Python indexing.\n",
    "It always returns a new `Tensor`:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "9zGDAPXVqdCm"
    },
    "outputs": [],
    "source": [
    "y_train[0], x_train[0, ::2]"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "QhJcOr8TmgmQ"
    },
    "source": [
    "PyTorch, like many libraries for high-performance array math,\n",
    "allows us to quickly and easily access metadata about our tensors."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "4ENirftAnIVM"
    },
    "source": [
    "The most important pieces of metadata about a `Tensor`,\n",
    "or any array, are its _dimension_\n",
    "and its _shape_.\n",
    "\n",
    "The dimension specifies how many indices you need to get a number\n",
    "out of an array."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "mhaN6qW0nA5t"
    },
    "outputs": [],
    "source": [
    "x_train.ndim, y_train.ndim"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "9pYEk13yoGgz"
    },
    "outputs": [],
    "source": [
    "x_train[0, 0], y_train[0]"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "rv2WWNcHkEeS"
    },
    "source": [
    "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
    "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "yZ6j-IGPJ3x7"
    },
    "outputs": [],
    "source": [
    "n, c = x_train.shape\n",
    "print(x_train.shape)\n",
    "print(y_train.shape)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "H-HFN9WJo6FK"
    },
    "source": [
    "This metadata serves a similar purpose for `Tensor`s\n",
    "as type metadata serves for other objects in Python\n",
    "(and other programming languages).\n",
    "\n",
    "That is, types tell us whether an object is an acceptable\n",
    "input for or output of a function.\n",
    "Many functions on `Tensor`s, like indexing,\n",
    "matrix multiplication,\n",
    "can only accept as input `Tensor`s of a certain shape and dimension\n",
    "and will return as output `Tensor`s of a certain shape and dimension.\n",
    "\n",
    "So printing `ndim` and `shape` to track\n",
    "what's happening to `Tensor`s during a computation\n",
    "is an important piece of the debugging toolkit!"
    ]
    },
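    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "For example, here's a sketch of a hypothetical shape bug\n",
    "(the tensors are made up for illustration):\n",
    "a matrix multiplication fails because the inner dimensions don't match,\n",
    "and printing the shapes pinpoints the problem."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "a = torch.randn(64, 784)\n",
    "b = torch.randn(10, 784)  # hypothetical bug: we wanted shape (784, 10) for a @ b\n",
    "\n",
    "try:\n",
    "    a @ b\n",
    "except RuntimeError as e:\n",
    "    print(e)  # the error message reports the mismatched shapes\n",
    "\n",
    "print(a.shape, b.shape)  # inner dimensions, 784 and 10, don't match\n",
    "print((a @ b.T).shape)  # transposing b fixes it: torch.Size([64, 10])"
    ]
    },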
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "wCjuWKKNrWGM"
    },
    "source": [
    "We won't spend much time here on writing raw array math code in PyTorch,\n",
    "nor will we spend much time on how PyTorch works.\n",
    "\n",
    "> If you'd like to get better at writing PyTorch code,\n",
    "try out\n",
    "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
    "We wrote a bit about what these puzzles reveal about programming\n",
    "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
    "\n",
 If you'd like">
    "> If you'd like to get a better understanding of the internals\n",
    "of PyTorch, check out\n",
    "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
    "\n",
    "As we'll see below,\n",
    "`torch.nn` provides most of what we need\n",
    "for building deep learning models."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "Li5e_jiJpLSI"
    },
    "source": [
    "The `Tensor`s inside of the `x_train` `Tensor`\n",
    "aren't just any old blocks of numbers:\n",
    "they're images of handwritten digits.\n",
    "The `y_train` `Tensor` contains the identities of those digits.\n",
    "\n",
    "Let's take a look at a random example:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "4VsHk6xNJ3x8"
    },
    "outputs": [],
    "source": [
    "# re-execute this cell for more samples\n",
    "import random\n",
    "\n",
    "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
    "\n",
    "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
    "\n",
    "idx = random.randint(0, len(x_train))\n",
    "example = x_train[idx]\n",
    "\n",
    "print(y_train[idx]) # the label of the image\n",
    "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "PC3pwoJ9s-ts"
    },
    "source": [
    "We want to build a deep network that can take in an image\n",
    "and return the number that's in the image.\n",
    "\n",
    "We'll build that network\n",
    "by fitting it to `x_train` and `y_train`.\n",
    "\n",
    "We'll first do our fitting with just basic `torch` components and Python,\n",
    "then we'll add in other `torch` gadgets and goodies\n",
    "until we have a more realistic neural network fitting loop.\n",
    "\n",
    "Later in the labs,\n",
    "we'll see how to even more quickly build\n",
    "performant, robust fitting loops\n",
    "that have even more features\n",
    "by using libraries built on top of PyTorch."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "DTLdqCIGJ3x6"
    },
    "source": [
    "# Building a DNN using only `torch.Tensor` methods and Python"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "8D8Xuh2xui3o"
    },
    "source": [
    "One of the really great features of PyTorch\n",
    "is that writing code in PyTorch feels\n",
    "very similar to writing other code in Python --\n",
    "unlike other deep learning frameworks\n",
    "that can sometimes feel like their own language\n",
    "or programming paradigm.\n",
    "\n",
    "This fact can sometimes be obscured\n",
    "when you're using lots of library code,\n",
    "so we start off by just using `Tensor`s and the Python standard library."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "tOV0bxySJ3x9"
    },
    "source": [
    "## Defining the model"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "ZLH_zUWkw3W0"
    },
    "source": [
    "We'll make the simplest possible neural network:\n",
    "a single layer that performs matrix multiplication,\n",
    "and adds a vector of biases.\n",
    "\n",
    "We'll need values for the entries of the matrix,\n",
    "which we generate randomly.\n",
    "\n",
    "We also need to tell PyTorch that we'll\n",
    "be taking gradients with respect to\n",
    "these `Tensor`s later, so we use `requires_grad`."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "1c21c8XQJ3x-"
    },
    "outputs": [],
    "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "weights = torch.randn(784, 10) / math.sqrt(784)\n",
    "weights.requires_grad_()\n",
    "bias = torch.zeros(10, requires_grad=True)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "GZC8A01sytm2"
    },
    "source": [
    "We can combine our beloved Python operators,\n",
    "like `+` and `*` and `@` and indexing,\n",
    "to define the model."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "8Eoymwooyq0-"
    },
    "outputs": [],
    "source": [
    "def linear(x: torch.Tensor) -> torch.Tensor:\n",
    " return x @ weights + bias"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "5tIRHR_HxeZf"
    },
    "source": [
    "We need to normalize our model's outputs with a `softmax`\n",
    "to get our model to output something we can use\n",
    "as a probability distribution --\n",
    "the probability that the network assigns to each label for the image.\n",
    "\n",
    "For that, we'll need some `torch` math functions,\n",
    "like `torch.sum` and `torch.exp`.\n",
    "\n",
    "We compute the logarithm of that softmax value\n",
    "in part for numerical stability reasons\n",
    "and in part because\n",
    "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "WuZRGSr4J3x-"
    },
    "outputs": [],
    "source": [
    "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
    " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
    "\n",
    "def model(xb: torch.Tensor) -> torch.Tensor:\n",
    " return log_softmax(linear(xb))"
    ]
    },
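    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "As a quick, illustrative sanity check:\n",
    "exponentiating the log-softmax outputs should recover probabilities\n",
    "that are non-negative and sum to one for each input."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "log_probs = model(x_train[:4])  # log-probabilities for four images\n",
    "probs = torch.exp(log_probs)\n",
    "\n",
    "print(probs.min() >= 0)  # probabilities are non-negative\n",
    "print(probs.sum(dim=1))  # each row sums to (approximately) 1"
    ]
    },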
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "-pBI4pOM011q"
    },
    "source": [
    "Typically, we split our dataset up into smaller \"batches\" of data\n",
    "and apply our model to one batch at a time.\n",
    "\n",
    "Since our dataset is just a `Tensor`,\n",
    "we can pull that off just with indexing:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "pXsHak23J3x_"
    },
    "outputs": [],
    "source": [
    "bs = 64 # batch size\n",
    "\n",
    "xb = x_train[0:bs] # a batch of inputs\n",
    "outs = model(xb) # outputs on that batch\n",
    "\n",
    "print(outs[0], outs.shape) # outputs on the first element of the batch"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "VPrG9x1DJ3x_"
    },
    "source": [
    "## Defining the loss and metrics"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "zEwPJmgZ1HIp"
    },
    "source": [
    "Our model produces outputs, but they are mostly wrong,\n",
    "since we set the weights randomly.\n",
    "\n",
    "How can we quantify just how wrong our model is,\n",
    "so that we can make it better?"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "JY-2QZEu1Xc7"
    },
    "source": [
    "We want to compare the outputs and the target labels,\n",
    "but the model outputs a probability distribution,\n",
    "and the labels are just numbers.\n",
    "\n",
    "We can take the label that had the highest probability\n",
    "(the index of the largest output for each input,\n",
    "aka the `argmax` over `dim`ension `1`)\n",
    "and treat that as the model's prediction\n",
    "for the digit in the image."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "_sHmDw_cJ3yC"
    },
    "outputs": [],
    "source": [
    "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
    " preds = torch.argmax(out, dim=1)\n",
    " return (preds == yb).float().mean()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "PfrDJb2EF_uz"
    },
    "source": [
    "If we run that function on our model's `out`put`s`,\n",
    "we can confirm that the random model isn't doing well --\n",
    "we expect to see that something around one in ten predictions are correct."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "8l3aRMNaJ3yD"
    },
    "outputs": [],
    "source": [
    "yb = y_train[0:bs]\n",
    "\n",
    "acc = accuracy(outs, yb)\n",
    "\n",
    "print(acc)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "fxRfO1HQ3VYs"
    },
    "source": [
    "We can calculate how good our network is doing,\n",
    "so are we ready to use optimization to make it do better?\n",
    "\n",
    "Not yet!\n",
    "To train neural networks, we use gradients\n",
    "(aka derivatives).\n",
    "So all of the functions we use need to be differentiable --\n",
    "in particuar they need to change smoothly so that a small change in input\n",
    "can only cause a small change in output.\n",
    "\n",
    "Our `argmax` breaks that rule\n",
    "(if the values at index `0` and index `N` are really close together,\n",
    "a tiny change can change the output by `N`)\n",
    "so we can't use it.\n",
    "\n",
    "If we try to run our `backward`s pass to get a gradient,\n",
    "we get a `RuntimeError`:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "g5AnK4md4kxv"
    },
    "outputs": [],
    "source": [
    "try:\n",
    " acc.backward()\n",
    "except RuntimeError as e:\n",
    " print(e)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "HJ4WWHHJ460I"
    },
    "source": [
    "So we'll need something else:\n",
    "a differentiable function that gets smaller when\n",
    "our model gets better, aka a `loss`.\n",
    "\n",
    "The typical choice is to maximize the\n",
    "probability the network assigns to the correct label.\n",
    "\n",
    "We could try doing that directly,\n",
    "but more generally,\n",
    "we want the model's output probability distribution\n",
    "to match what we provide it -- \n",
    "here, we claim we're 100% certain in every label,\n",
    "but in general we allow for uncertainty.\n",
    "We quantify that match with the\n",
    "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
    "\n",
    "Cross entropies\n",
    "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
    "including more familiar functions like the\n",
    "mean squared error and the mean absolute error.\n",
    "\n",
    "We can calculate it directly from the outputs and target labels\n",
    "using some cute tricks:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "-k20rW_rJ3yA"
    },
    "outputs": [],
    "source": [
    "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
    " return -output[range(target.shape[0]), target].mean()\n",
    "\n",
    "loss_func = cross_entropy"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "YZa1DSGN7zPK"
    },
    "source": [
    "With random guessing on a dataset with 10 equally likely options,\n",
    "we expect our loss value to be close to the negative logarithm of 1/10:\n",
    "the amount of entropy in a uniformly random digit."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "1bKRJ90MJ3yB"
    },
    "outputs": [],
    "source": [
    "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "hTgFTdVgAGJW"
    },
    "source": [
    "Now we can call `.backward` without PyTorch complaining:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "1LH_ZpY0_e_6"
    },
    "outputs": [],
    "source": [
    "loss = loss_func(outs, yb)\n",
    "\n",
    "loss.backward()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "ji0FA3dDACUk"
    },
    "source": [
    "But wait, where are the gradients?\n",
    "They weren't returned by `loss` above,\n",
    "so where could they be?\n",
    "\n",
    "They've been stored in the `.grad` attribute\n",
    "of the parameters of our model,\n",
    "`weights` and `bias`:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "Zgtyyhp__s8a"
    },
    "outputs": [],
    "source": [
    "bias.grad"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "dWTYno0JJ3yD"
    },
    "source": [
    "## Defining and running the fitting loop"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "TTR2Qo9F8ZLQ"
    },
    "source": [
    "We now have all the ingredients we need to fit a neural network to data:\n",
    "- data (`x_train`, `y_train`)\n",
    "- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
    "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
    "\n",
    "We can put them together into a training loop\n",
    "just using normal Python features,\n",
    "like `for` loops, indexing, and function calls:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "SzNZVEiVJ3yE"
    },
    "outputs": [],
    "source": [
    "lr = 0.5 # learning rate hyperparameter\n",
    "epochs = 2 # how many epochs to train for\n",
    "\n",
    "for epoch in range(epochs): # loop over the data repeatedly\n",
    " for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
    " start_idx = ii * bs # we are ii batches in, each of size bs\n",
    " end_idx = start_idx + bs # and we want the next bs entires\n",
    "\n",
    " # pull batches from x and from y\n",
    " xb = x_train[start_idx:end_idx]\n",
    " yb = y_train[start_idx:end_idx]\n",
    "\n",
    " # run model\n",
    " pred = model(xb)\n",
    "\n",
    " # get loss\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " # calculate the gradients with a backwards pass\n",
    " loss.backward()\n",
    "\n",
    " # update the parameters\n",
    " with torch.no_grad(): # we don't want to track gradients through this part!\n",
    " # SGD learning rule: update with negative gradient scaled by lr\n",
    " weights -= weights.grad * lr\n",
    " bias -= bias.grad * lr\n",
    "\n",
    " # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
    " # until you say so -- by explicitly \"deleting\" them,\n",
    " # i.e. setting the gradients to 0.\n",
    " weights.grad.zero_()\n",
    " bias.grad.zero_()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "9J-BfH1e_Jkx"
    },
    "source": [
    "To check whether things are working,\n",
    "we confirm that the value of the `loss` has gone down\n",
    "and the `accuracy` has gone up:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "mHgGCLaVJ3yE"
    },
    "outputs": [],
    "source": [
    "print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "E1ymEPYdcRHO"
    },
    "source": [
    "We can also run the model on a few examples\n",
    "to get a sense for how it's doing --\n",
    "always good for detecting bugs in our evaluation metrics!"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "O88PWejlcSTL"
    },
    "outputs": [],
    "source": [
    "# re-execute this cell for more samples\n",
    "idx = random.randint(0, len(x_train))\n",
    "example = x_train[idx:idx+1]\n",
    "\n",
    "out = model(example)\n",
    "\n",
    "print(out.argmax())\n",
    "wandb.Image(example.reshape(28, 28)).image"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "7L1Gq1N_J3yE"
    },
    "source": [
    "# Refactoring with core `torch.nn` components"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "EE5nUXMG_Yry"
    },
    "source": [
    "This works!\n",
    "But it's rather tedious and manual --\n",
    "we have to track what the parameters of our model are,\n",
    "apply the parameter updates to each one individually ourselves,\n",
    "iterate over the dataset directly, etc.\n",
    "\n",
    "It's also very literal:\n",
    "many assumptions about our problem are hard-coded in the loop.\n",
    "If our dataset was, say, stored in CSV files\n",
    "and too large to fit in RAM,\n",
    "we'd have to rewrite most of our training code.\n",
    "\n",
    "For the next few sections,\n",
    "we'll progressively refactor this code to\n",
    "make it shorter, cleaner,\n",
    "and more extensible\n",
    "using tools from the sublibraries of PyTorch:\n",
    "`torch.nn`, `torch.optim`, and `torch.utils.data`."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "BHEixRsbJ3yF"
    },
    "source": [
    "## Using `torch.nn.functional` for stateless computation"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "9k94IlN58lWa"
    },
    "source": [
    "First, let's drop that `cross_entropy` and `log_softmax`\n",
    "we implemented ourselves --\n",
    "whenever you find yourself implementing basic mathematical operations\n",
    "in PyTorch code you want to put in production,\n",
    "take a second to check whether the code you need's not out\n",
    "there in a library somewhere.\n",
    "You'll get fewer bugs and faster code for less effort!"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "sP-giy1a9Ct4"
    },
    "source": [
    "Both of those functions operated on their inputs\n",
    "without reference to any global variables,\n",
    "so we find their implementation in `torch.nn.functional`,\n",
    "where stateless computations live."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "vfWyJW1sJ3yF"
    },
    "outputs": [],
    "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    " return xb @ weights + bias"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "kqYIkcvpJ3yF"
    },
    "outputs": [],
    "source": [
    "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "vXFyM1tKJ3yF"
    },
    "source": [
    "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "PInL-9sbCKnv"
    },
    "source": [
    "Perhaps the biggest issue with our setup is how we're handling state.\n",
    "\n",
    "The `model` function refers to two global variables: `weights` and `bias`.\n",
    "These variables are critical for it to run,\n",
    "but they are defined outside of the function\n",
    "and are manipulated willy-nilly by other operations.\n",
    "\n",
    "This problem arises because of a fundamental tension in\n",
    "deep neural networks.\n",
    "We want to use them _as functions_ --\n",
    "when the time comes to make predictions in production,\n",
    "we put inputs in and get outputs out,\n",
    "just like any other function.\n",
    "But neural networks are fundamentally stateful,\n",
    "because they are _parameterized_ functions,\n",
    "and fiddling with the values of those parameters\n",
    "is the purpose of optimization.\n",
    "\n",
    "PyTorch's solution to this is the `nn.Module` class:\n",
    "a Python class that is callable like a function\n",
    "but tracks state like an object.\n",
    "\n",
    "Whatever `Tensor`s representing state we want PyTorch\n",
    "to track for us inside of our model\n",
    "get defined as `nn.Parameter`s and attached to the model\n",
    "as attributes."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "A34hxhd0J3yF"
    },
    "outputs": [],
    "source": [
    "from torch import nn\n",
    "\n",
    "\n",
    "class MNISTLogistic(nn.Module):\n",
    " def __init__(self):\n",
    " super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
    " self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
    " self.bias = nn.Parameter(torch.zeros(10))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "pFD_sIRaFbbx"
    },
    "source": [
    "We define the computation that uses that state\n",
    "in the `.forward` method.\n",
    "\n",
    "Using some behind-the-scenes magic,\n",
    "this method gets called if we treat\n",
    "the instantiated `nn.Module` like a function by\n",
    "passing it arguments.\n",
    "You can give similar special powers to your own classes\n",
    "by defining `__call__` \"magic dunder\" method\n",
    "on them.\n",
    "\n",
    "> <small> <small> We've separated the definition of the `.forward` method\n",
    "from the definition of the class above and\n",
    "attached the method to the class manually below.\n",
    "We only do this to make the construction of the class\n",
    "easier to read and understand in the context this notebook --\n",
    "a neat little trick we'll use a lot in these labs.\n",
    "Normally, we'd just define the `nn.Module` all at once.</small></small>"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "0QAKK3dlFT9w"
    },
    "outputs": [],
    "source": [
    "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
    " return xb @ self.weights + self.bias\n",
    "\n",
    "MNISTLogistic.forward = forward\n",
    "\n",
    "model = MNISTLogistic() # instantiated as an object\n",
    "print(model(xb)[:4]) # callable like a function\n",
    "loss = loss_func(model(xb), yb) # composable like a function\n",
    "loss.backward() # we can still take gradients through it\n",
    "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
    ]
    },
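    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "Here's that `__call__` mechanism sketched with a plain Python class --\n",
    "the `Scaler` class is made up for illustration and is not part of `torch`:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "class Scaler:\n",
    "    \"\"\"Callable like a function, stateful like an object.\"\"\"\n",
    "\n",
    "    def __init__(self, factor):\n",
    "        self.factor = factor  # state, playing the role of a parameter\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.factor * x  # computation, playing the role of .forward\n",
    "\n",
    "\n",
    "double = Scaler(2)\n",
    "print(double(21))  # 42 -- the instance is called like a function"
    ]
    },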
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "r-Yy2eYTHMVl"
    },
    "source": [
    "But how do we apply our updates?\n",
    "Do we need to access `model.weights.grad` and `model.weights`,\n",
    "like we did in our first implementation?\n",
    "\n",
    "Luckily, we don't!\n",
    "We can iterate over all of our model's `torch.nn.Parameters`\n",
    "via the `.parameters` method:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "vM59vE-5JiXV"
    },
    "outputs": [],
    "source": [
    "print(*list(model.parameters()), sep=\"\\n\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "tbFCdWBkNft0"
    },
    "source": [
    "That means we no longer need to assume we know the names\n",
    "of the model's parameters when we do our update --\n",
    "we can reuse the same loop with different models."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "hA925fIUK0gg"
    },
    "source": [
    "Let's wrap all of that up into a single function to `fit` our model:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "q9NxJZTOJ3yG"
    },
    "outputs": [],
    "source": [
    "def fit():\n",
    " for epoch in range(epochs):\n",
    " for ii in range((n - 1) // bs + 1):\n",
    " start_idx = ii * bs\n",
    " end_idx = start_idx + bs\n",
    " xb = x_train[start_idx:end_idx]\n",
    " yb = y_train[start_idx:end_idx]\n",
    " pred = model(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " with torch.no_grad():\n",
    " for p in model.parameters(): # finds params automatically\n",
    " p -= p.grad * lr\n",
    " model.zero_grad()\n",
    "\n",
    "fit()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "Mjmsb94mK8po"
    },
    "source": [
    "and check that we didn't break anything,\n",
    "i.e. that our model still gets accuracy much higher than 10%:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "Vo65cLS5J3yH"
    },
    "outputs": [],
    "source": [
    "print(accuracy(model(xb), yb))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "fxYq2sCLJ3yI"
    },
    "source": [
    "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "95c67wZCMynl"
    },
    "source": [
    "Our model's state is being handled respectably,\n",
    "our fitting loop is 2x shorter,\n",
    "and we can train different models if we'd like.\n",
    "\n",
    "But we're not done yet!\n",
    "Many steps we're doing manually above\n",
    "are already built in to `torch`."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "CE2VFjDZJ3yI"
    },
    "source": [
    "## Using `torch.nn.Linear` for the model definition"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "Zvcnrz2uJ3yI"
    },
    "source": [
    "As with our hand-rolled `cross_entropy`\n",
    "that could be profitably replaced with\n",
    "the industrial grade `nn.functional.cross_entropy`,\n",
    "we should replace our bespoke linear layer\n",
    "with something made by experts.\n",
    "\n",
    "Instead of defining `nn.Parameters`,\n",
    "effectively raw `Tensor`s, as attributes\n",
    "of our `nn.Module`,\n",
    "we can define other `nn.Module`s as attributes.\n",
    "PyTorch assigns the `nn.Parameters`\n",
    "of any child `nn.Module`s to the parent, recursively.\n",
    "\n",
    "These `nn.Module`s are reusable --\n",
    "say, if we want to make a network with multiple layers of the same type --\n",
    "and there are lots of them already defined:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "l-EKdhXcPjq2"
    },
    "outputs": [],
    "source": [
    "import textwrap\n",
    "\n",
    "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "KbIIQMaBQC45"
    },
    "source": [
    "We want the humble `nn.Linear`,\n",
    "which applies the same\n",
    "matrix multiplication and bias operation."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "JHwS-1-rJ3yJ"
    },
    "outputs": [],
    "source": [
    "class MNISTLogistic(nn.Module):\n",
    " def __init__(self):\n",
    " super().__init__()\n",
    " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
    "\n",
    " def forward(self, xb):\n",
    " return self.lin(xb) # call nn.Linear.forward here"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "Mcb0UvcmJ3yJ"
    },
    "outputs": [],
    "source": [
    "model = MNISTLogistic()\n",
    "print(loss_func(model(xb), yb)) # loss is still close to 2.3"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "5hcjV8A2QjQJ"
    },
    "source": [
    "We can see that the `nn.Linear` module is a \"child\"\n",
    "of the `model`,\n",
    "and we don't see the matrix of weights and the bias vector:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "yKkU-GIPOQq4"
    },
    "outputs": [],
    "source": [
    "print(*list(model.children()))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "kUdhpItWQui_"
    },
    "source": [
    "but if we ask for the model's `.parameters`,\n",
    "we find them:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "G1yGOj2LNDsS"
    },
    "outputs": [],
    "source": [
    "print(*list(model.parameters()), sep=\"\\n\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "DFlQyKl6J3yJ"
    },
    "source": [
    "## Applying gradients with `torch.optim.Optimizer`"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "IqImMaenJ3yJ"
    },
    "source": [
    "Applying gradients to optimize parameters\n",
    "and resetting those gradients to zero\n",
    "are very common operations.\n",
    "\n",
    "So why are we doing that by hand?\n",
    "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
    "we don't have to --\n",
    "we just need to point a `torch.optim.Optimizer`\n",
    "at the parameters of our model.\n",
    "\n",
    "While we're at it, we can also use a more sophisticated optimizer --\n",
    "`Adam` is a common first choice."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "f5AUNLEKJ3yJ"
    },
    "outputs": [],
    "source": [
    "from torch import optim\n",
    "\n",
    "\n",
    "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
    " return optim.Adam(model.parameters(), lr=3e-4)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "jK9dy0sNJ3yK"
    },
    "outputs": [],
    "source": [
    "model = MNISTLogistic()\n",
    "opt = configure_optimizer(model)\n",
    "\n",
    "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
    "\n",
    "for epoch in range(epochs):\n",
    " for ii in range((n - 1) // bs + 1):\n",
    " start_idx = ii * bs\n",
    " end_idx = start_idx + bs\n",
    " xb = x_train[start_idx:end_idx]\n",
    " yb = y_train[start_idx:end_idx]\n",
    " pred = model(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " opt.step()\n",
    " opt.zero_grad()\n",
    "\n",
    "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "4yk9re3HJ3yK"
    },
    "source": [
    "## Organizing data with `torch.utils.data.Dataset`"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "0ap3fcZpTIqJ"
    },
    "source": [
    "We're also manually handling the data.\n",
    "First, we're independently and manually aligning\n",
    "the inputs, `x_train`, and the outputs, `y_train`.\n",
    "\n",
    "Aligned data is important in ML.\n",
    "We want a way to combine multiple data sources together\n",
    "and index into them simultaneously.\n",
    "\n",
    "That's done with `torch.utils.data.Dataset`.\n",
    "Just inherit from it and implement two methods to support indexing:\n",
    "`__getitem__` and `__len__`."
    ]
    },
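    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "For illustration, a minimal hand-rolled version might look like the sketch below\n",
    "(the `TensorPairDataset` name is made up):"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class TensorPairDataset(Dataset):  # hypothetical minimal Dataset\n",
    "    def __init__(self, x, y):\n",
    "        assert len(x) == len(y)  # keep inputs and targets aligned\n",
    "        self.x, self.y = x, y\n",
    "\n",
    "    def __getitem__(self, idx):  # index both tensors simultaneously\n",
    "        return self.x[idx], self.y[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.x)"
    ]
    },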
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "HPj25nkoVWRi"
    },
    "source": [
    "We'll cheat a bit here and pull in the `BaseDataset`\n",
    "class from the `text_recognizer` library,\n",
    "so that we can start getting some exposure\n",
    "to the codebase for the labs."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "NpltQ-4JJ3yK"
    },
    "outputs": [],
    "source": [
    "from text_recognizer.data.util import BaseDataset\n",
    "\n",
    "\n",
    "train_ds = BaseDataset(x_train, y_train)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "zV1bc4R5Vz0N"
    },
    "source": [
    "The cell below will pull up the documentation for this class,\n",
    "which effectively just indexes into the two `Tensor`s simultaneously.\n",
    "\n",
    "It can also apply transformations to the inputs and targets.\n",
    "We'll see that later."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "XUWJ8yIWU28G"
    },
    "outputs": [],
    "source": [
    "BaseDataset??"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "zMQDHJNzWMtf"
    },
    "source": [
    "This makes our code a tiny bit cleaner:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "6iyqG4kEJ3yK"
    },
    "outputs": [],
    "source": [
    "model = MNISTLogistic()\n",
    "opt = configure_optimizer(model)\n",
    "\n",
    "\n",
    "for epoch in range(epochs):\n",
    " for ii in range((n - 1) // bs + 1):\n",
    " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
    " pred = model(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " opt.step()\n",
    " opt.zero_grad()\n",
    "\n",
    "print(loss_func(model(xb), yb))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "pTtRPp_iJ3yL"
    },
    "source": [
    "## Batching up data with `torch.utils.data.DataLoader`"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "FPnaMyokWSWv"
    },
    "source": [
    "We're also still manually building our batches.\n",
    "\n",
    "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
    "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
    "\n",
    "We just need to hand our `Dataset` to the `DataLoader`\n",
    "and choose a `batch_size`.\n",
    "\n",
    "We can tune that parameter and other `DataLoader` arguments,\n",
    "like `num_workers` and `pin_memory`,\n",
    "to improve the performance of our training loop.\n",
    "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
    "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "aqXX7JGCJ3yL"
    },
    "outputs": [],
    "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "train_ds = BaseDataset(x_train, y_train)\n",
    "train_dataloader = DataLoader(train_ds, batch_size=bs)"
    ]
    },
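    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "As a sketch of that tuning\n",
    "(the values here are illustrative, not recommendations),\n",
    "those arguments are passed as keywords:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "tuned_dataloader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=bs,\n",
    "    shuffle=True,  # reshuffle the data each epoch\n",
    "    num_workers=2,  # load batches in parallel worker processes\n",
    "    pin_memory=True,  # speed up CPU-to-GPU memory transfers\n",
    ")"
    ]
    },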
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "iWry2CakJ3yL"
    },
    "outputs": [],
    "source": [
    "def fit(self: nn.Module, train_dataloader: DataLoader):\n",
    " opt = configure_optimizer(self)\n",
    "\n",
    " for epoch in range(epochs):\n",
    " for xb, yb in train_dataloader:\n",
    " pred = self(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " opt.step()\n",
    " opt.zero_grad()\n",
    "\n",
    "MNISTLogistic.fit = fit"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "9pfdSJBIXT8o"
    },
    "outputs": [],
    "source": [
    "model = MNISTLogistic()\n",
    "\n",
    "model.fit(train_dataloader)\n",
    "\n",
    "print(loss_func(model(xb), yb))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "RAs8-3IfJ3yL"
    },
    "source": [
    "Compare the ten line `fit` function with our first training loop (reproduced below) --\n",
    "much cleaner _and_ much more powerful!"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "_a51dZrLJ3yL"
    },
    "source": [
    "```python\n",
    "lr = 0.5 # learning rate\n",
    "epochs = 2 # how many epochs to train for\n",
    "\n",
    "for epoch in range(epochs):\n",
    " for ii in range((n - 1) // bs + 1):\n",
    " start_idx = ii * bs\n",
    " end_idx = start_idx + bs\n",
    " xb = x_train[start_idx:end_idx]\n",
    " yb = y_train[start_idx:end_idx]\n",
    " pred = model(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " with torch.no_grad():\n",
    " weights -= weights.grad * lr\n",
    " bias -= bias.grad * lr\n",
    " weights.grad.zero_()\n",
    " bias.grad.zero_()\n",
    "```"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "jiQe3SEWyZo4"
    },
    "source": [
    "## Swapping in another model"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "KykHpZEWyZo4"
    },
    "source": [
    "To see that our new `.fit` is more powerful,\n",
    "let's use it with a different model.\n",
    "\n",
    "Specifically, let's draw in the `MLP`,\n",
    "or \"multi-layer perceptron\" model\n",
    "from the `text_recognizer` library\n",
    "in our codebase."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "1FtGJg1CyZo4"
    },
    "outputs": [],
    "source": [
    "from text_recognizer.models.mlp import MLP\n",
    "\n",
    "\n",
    "MLP.fit = fit # attach our fitting loop"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "kJiP3a-8yZo4"
    },
    "source": [
    "If you look in the `.forward` method of the `MLP`,\n",
    "you'll see that it uses\n",
    "some modules and functions we haven't seen, like\n",
    "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
    "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
    "but otherwise fits the interface of our training loop:\n",
    "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "hj-0UdJwyZo4"
    },
    "outputs": [],
    "source": [
    "MLP.forward??"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "FS7dxQ4VyZo4"
    },
    "source": [
    "If we look at the constructor, `__init__`,\n",
    "we see that the `nn.Module`s (`fc` and `dropout`)\n",
    "are initialized and attached as attributes."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "x0NpkeA8yZo5"
    },
    "outputs": [],
    "source": [
    "MLP.__init__??"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "Uygy5HsUyZo5"
    },
    "source": [
    "We also see that we are required to provide a `data_config`\n",
    "dictionary and can optionally configure the module with `args`.\n",
    "\n",
    "For now, we'll only do the bare minimum and specify\n",
    "the contents of the `data_config`:\n",
    "the `input_dims` for `x` and the `mapping`\n",
    "from class index in `y` to class label,\n",
    "which we can see are used in the `__init__` method."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "y6BEl_I-yZo5"
    },
    "outputs": [],
    "source": [
    "digits_to_9 = list(range(10))\n",
    "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
    "data_config"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "bEuNc38JyZo5"
    },
    "outputs": [],
    "source": [
    "model = MLP(data_config)\n",
    "model"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "CWQK2DWWyZo6"
    },
    "source": [
    "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "zs1s6ahUyZo8"
    },
    "outputs": [],
    "source": [
    "model.fc1.weight"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "JVLkK78FyZo8"
    },
    "source": [
    "But that doesn't matter for our fitting loop,\n",
    "which happily optimizes this model on batches from the `train_dataloader`,\n",
    "though it takes a bit longer."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "Y-DItXLoyZo9"
    },
    "outputs": [],
    "source": [
    "%%time\n",
    "\n",
    "print(\"before training:\", loss_func(model(xb), yb))\n",
    "\n",
    "train_ds = BaseDataset(x_train, y_train)\n",
    "train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
    "fit(model, train_dataloader)\n",
    "\n",
    "print(\"after training:\", loss_func(model(xb), yb))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "9QgTv2yzJ3yM"
    },
    "source": [
    "# Extra goodies: data organization, validation, and acceleration"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "Vx-CcCesbmyw"
    },
    "source": [
    "Before we've got a DNN fitting loop that's welcome in polite company,\n",
    "we need three more features:\n",
    "organized data loading code, validation, and GPU acceleration."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "8LWja5aDJ3yN"
    },
    "source": [
    "## Making the GPU go brrrrr"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "7juxQ_Kp-Tx0"
    },
    "source": [
    "Everything we've done so far has been on\n",
    "the central processing unit of the computer, or CPU.\n",
    "When programming in Python,\n",
    "it is on the CPU that\n",
    "almost all of our code becomes concrete instructions\n",
    "that cause a machine move around electrons."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "R25L3z8eAWIO"
    },
    "source": [
    "That's okay for small-to-medium neural networks,\n",
    "but computation quickly becomes a bottleneck that makes achieving\n",
    "good performance infeasible.\n",
    "\n",
    "In general, the problem of CPUs,\n",
    "which are general purpose computing devices,\n",
    "being too slow is solved by using more specialized accelerator chips --\n",
    "in the extreme case, application-specific integrated circuits (ASICs)\n",
    "that can only perform a single task,\n",
    "the hardware equivalents of\n",
    "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
    "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
    "\n",
    "Luckily, really excellent chips\n",
    "for accelerating deep learning are readily available\n",
    "as a consumer product:\n",
    "graphics processing units (GPUs),\n",
    "which are designed to perform large matrix multiplications in parallel.\n",
    "Their name derives from their origins\n",
    "applying large matrix multiplications to manipulate shapes and textures\n",
    "in for graphics engines for video games and CGI.\n",
    "\n",
    "If your system has a GPU and the right libraries installed\n",
    "for `torch` compatibility,\n",
    "the cell below will print information about its state."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "Xxy-Gt9wJ3yN"
    },
    "outputs": [],
    "source": [
    "if torch.cuda.is_available():\n",
    " !nvidia-smi\n",
    "else:\n",
    " print(\"☹️\")"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "x6qAX1OECiWk"
    },
    "source": [
    "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
    "even simultaneously, which can be critical for high performance.\n",
    "\n",
    "So once we start using acceleration, we need to be more precise about where the\n",
    "data inside our `Tensor`s lives --\n",
    "on which physical `torch.device` it can be found.\n",
    "\n",
    "On compatible systems, the cell below will\n",
    "move all of the model's parameters `.to` the GPU\n",
    "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
    "and then move a batch of inputs and targets there as well\n",
    "before applying the model and calculating the loss.\n",
    "\n",
    "To confirm this worked, look for the name of the device in the output of the cell,\n",
    "alongside other information about the loss `Tensor`."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "jGkpfEmbJ3yN"
    },
    "outputs": [],
    "source": [
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "loss_func(model(xb.to(device)), yb.to(device))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "-zdPR06eDjIX"
    },
    "source": [
    "Rather than rewrite our entire `.fit` function,\n",
    "we'll make use of the features of the `text_recognizer.data.utils.BaseDataset`.\n",
    "\n",
    "Specifically,\n",
    "we can provide a `transform` that is called on the inputs\n",
    "and a `target_transform` that is called on the labels\n",
    "before they are returned.\n",
    "In the FSDL codebase,\n",
    "this feature is used for data preparation, like\n",
    "reshaping, resizing,\n",
    "and normalization.\n",
    "\n",
    "We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "m8WQS9Zo_Did"
    },
    "outputs": [],
    "source": [
    "def push_to_device(tensor):\n",
    " return tensor.to(device)\n",
    "\n",
    "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
    "train_dataloader = DataLoader(train_ds, batch_size=bs)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "nmg9HMSZFmqR"
    },
    "source": [
    "We don't need to change anything about our fitting code to run it on the GPU!\n",
    "\n",
    "Note: given the small size of this model and the data,\n",
    "the speedup here can sometimes be fairly moderate (like 2x).\n",
    "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "v1TVc06NkXrU"
    },
    "outputs": [],
    "source": [
    "%%time\n",
    "\n",
    "model = MLP(data_config)\n",
    "model.to(device)\n",
    "\n",
    "model.fit(train_dataloader)\n",
    "\n",
    "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "L7thbdjKTjAD"
    },
    "source": [
    "Writing high performance GPU-accelerated neural network code is challenging.\n",
    "There are many sharp edges, so the default\n",
    "strategy is imitation (basing all work on existing verified quality code)\n",
    "and conservatism bordering on paranoia about change.\n",
    "For a casual introduction to some of the core principles, see\n",
    "[Horace He's blogpost](https://horace.io/brrr_intro.html)."
    ]
    },
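    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "To get a feel for the scale of the difference,\n",
    "the cell below is a rough illustration (not a rigorous benchmark!)\n",
    "that times one large matrix multiplication on the CPU and then on the GPU.\n",
    "The matrix size is an arbitrary choice,\n",
    "and the `torch.cuda.synchronize` calls are needed because GPU work\n",
    "is launched asynchronously --\n",
    "without them, we'd only time the kernel launch."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "import time\n",
    "\n",
    "a, b = torch.randn(4096, 4096), torch.randn(4096, 4096)\n",
    "\n",
    "start = time.perf_counter()\n",
    "a @ b  # runs on the CPU\n",
    "cpu_time = time.perf_counter() - start\n",
    "print(f\"CPU: {cpu_time:.3f} s\")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    a_gpu, b_gpu = a.to(device), b.to(device)\n",
    "    a_gpu @ b_gpu  # warmup: the first call pays one-time setup costs\n",
    "    torch.cuda.synchronize()  # wait for all queued GPU work to finish\n",
    "    start = time.perf_counter()\n",
    "    a_gpu @ b_gpu\n",
    "    torch.cuda.synchronize()\n",
    "    gpu_time = time.perf_counter() - start\n",
    "    print(f\"GPU: {gpu_time:.3f} s\")"
    ]
    },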
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "LnpbEVE5J3yM"
    },
    "source": [
    "## Adding validation data and organizing data code with a `DataModule`"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "EqYHjiG8b_4J"
    },
    "source": [
    "Just doing well on data you've seen before is not that impressive --\n",
    "the network could just memorize the label for each input digit.\n",
    "\n",
    "We need to check performance on a set of data points that weren't used\n",
    "directly to optimize the model,\n",
    "commonly called the validation set."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "7e6z-Fh8dOnN"
    },
    "source": [
    "We already downloaded one up above,\n",
    "but that was all the way at the beginning of the notebook,\n",
    "and I've already forgotten about it.\n",
    "\n",
    "In general, it's easy for data-loading code,\n",
    "the redheaded stepchild of the ML codebase,\n",
    "to become messy and fall out of sync.\n",
    "\n",
    "A proper `DataModule` collects up all of the code required\n",
    "to prepare data on a machine,\n",
    "sets it up as a collection of `Dataset`s,\n",
    "and turns those `Dataset`s into `DataLoader`s,\n",
    "as below:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "0WxgRa2GJ3yM"
    },
    "outputs": [],
    "source": [
    "class MNISTDataModule:\n",
    " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
    " filename = \"mnist.pkl.gz\"\n",
    " \n",
    " def __init__(self, dir, bs=32):\n",
    " self.dir = dir\n",
    " self.bs = bs\n",
    " self.path = self.dir / self.filename\n",
    "\n",
    " def prepare_data(self):\n",
    " if not (self.path).exists():\n",
    " content = requests.get(self.url + self.filename).content\n",
    " self.path.open(\"wb\").write(content)\n",
    "\n",
    " def setup(self):\n",
    " with gzip.open(self.path, \"rb\") as f:\n",
    " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
    "\n",
    " x_train, y_train, x_valid, y_valid = map(\n",
    " torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    " )\n",
    " \n",
    " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
    " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
    "\n",
    " def train_dataloader(self):\n",
    " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
    " \n",
    " def val_dataloader(self):\n",
    " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
    ]
    },
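    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "As a quick smoke test of the class above --\n",
    "a sketch that assumes the `path` variable from the data download\n",
    "near the top of the notebook --\n",
    "we can walk through its lifecycle by hand and pull out one batch:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "dm = MNISTDataModule(dir=path, bs=32)\n",
    "dm.prepare_data()  # downloads the data, if it isn't already on disk\n",
    "dm.setup()  # builds the train and validation Datasets\n",
    "\n",
    "xb_demo, yb_demo = next(iter(dm.train_dataloader()))\n",
    "xb_demo.shape, yb_demo.shape, xb_demo.device  # batch shapes, plus which device the transform put them on"
    ]
    },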
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "x-8T_MlWifMe"
    },
    "source": [
    "We'll cover `DataModule`s in more detail later.\n",
    "\n",
    "We can now incorporate our `DataModule`\n",
    "into the fitting pipeline\n",
    "by calling its methods as needed:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "mcFcbRhSJ3yN"
    },
    "outputs": [],
    "source": [
    "def fit(self: nn.Module, datamodule):\n",
    " datamodule.prepare_data()\n",
    " datamodule.setup()\n",
    "\n",
    " val_dataloader = datamodule.val_dataloader()\n",
    " \n",
    " self.eval()\n",
    " with torch.no_grad():\n",
    " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
    "\n",
    " print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
    "\n",
    " opt = configure_optimizer(self)\n",
    " train_dataloader = datamodule.train_dataloader()\n",
    " for epoch in range(epochs):\n",
    " self.train()\n",
    " for xb, yb in train_dataloader:\n",
    " pred = self(xb)\n",
    " loss = loss_func(pred, yb)\n",
    "\n",
    " loss.backward()\n",
    " opt.step()\n",
    " opt.zero_grad()\n",
    "\n",
    " self.eval()\n",
    " with torch.no_grad():\n",
    " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
    "\n",
    " print(epoch, valid_loss / len(val_dataloader))\n",
    "\n",
    "\n",
    "MNISTLogistic.fit = fit\n",
    "MLP.fit = fit"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "-Uqey9w6jkv9"
    },
    "source": [
    "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
    "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n",
    "then you can train a network with just the cell below."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "uxN1yV6DX6Nz"
    },
    "outputs": [],
    "source": [
    "model = MLP(data_config)\n",
    "model.to(device)\n",
    "\n",
    "datamodule = MNISTDataModule(dir=path, bs=32)\n",
    "\n",
    "model.fit(datamodule=datamodule)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "2zHA12Iih0ML"
    },
    "source": [
    "You may have noticed a few other changes in the `.fit` method:\n",
    "\n",
    "- `self.eval` vs `self.train`:\n",
    "it's helpful to have features of neural networks that behave differently in `train`ing\n",
    "than they do in production or `eval`uation.\n",
    "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
    "and\n",
    "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
    "are among the most popular examples.\n",
    "We need to take this into account now that we\n",
    "have a validation loop.\n",
    "- The return of `torch.no_grad`: in our first few implementations,\n",
    "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
    "Now, we need to use it to avoid tracking gradients during validation."
    ]
    },
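    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "To make those two points concrete,\n",
    "here's a small sketch:\n",
    "a standalone `Dropout` layer zeroes entries (and rescales the rest) in `train` mode\n",
    "but acts as the identity in `eval` mode,\n",
    "and under `torch.no_grad`, outputs come back with `requires_grad=False`,\n",
    "so no computation graph is kept around."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "dropout = nn.Dropout(p=0.5)\n",
    "x = torch.ones(8)\n",
    "\n",
    "dropout.train()\n",
    "print(\"train mode:\", dropout(x))  # ~half the entries zeroed, survivors scaled by 1 / (1 - p) = 2\n",
    "\n",
    "dropout.eval()\n",
    "print(\"eval mode: \", dropout(x))  # the input comes back unchanged\n",
    "\n",
    "layer = nn.Linear(4, 2)\n",
    "print(\"normally:     \", layer(torch.ones(3, 4)).requires_grad)  # True -- autograd tracks the computation\n",
    "with torch.no_grad():\n",
    "    print(\"under no_grad:\", layer(torch.ones(3, 4)).requires_grad)  # False -- no graph is built"
    ]
    },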
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "BaODkqTnJ3yO"
    },
    "source": [
    "This is starting to get a bit hairy again!\n",
    "We're back up to about 30 lines of code,\n",
    "right where we started\n",
    "(but now with way more features!).\n",
    "\n",
    "Much like `torch.nn` provides useful tools and interfaces for\n",
    "defining neural networks,\n",
    "iterating over batches,\n",
    "and calculating gradients,\n",
    "frameworks on top of PyTorch, like\n",
    "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
    "provide useful tools and interfaces\n",
    "for an even higher level of abstraction over neural network training.\n",
    "\n",
    "For serious deep learning codebases,\n",
    "you'll want to use a framework at that level of abstraction --\n",
    "either one of the popular open frameworks or one developed in-house.\n",
    "\n",
    "For most of these frameworks,\n",
    "you'll still need facility with core PyTorch:\n",
    "at least for defining models and\n",
    "often for defining data pipelines as well."
    ]
    },
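    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "As a taste, here's a minimal, illustrative sketch of how our setup might look\n",
    "as a Lightning `LightningModule` --\n",
    "an assumption-laden sketch, not the actual Lightning code from the FSDL codebase,\n",
    "and it requires `pytorch-lightning` to be installed\n",
    "(e.g. via `pip install pytorch-lightning`):"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "try:\n",
    "    import pytorch_lightning as pl\n",
    "except ImportError:\n",
    "    pl = None\n",
    "    print(\"pytorch_lightning is not installed; skipping this sketch\")\n",
    "\n",
    "if pl is not None:\n",
    "\n",
    "    class LitModel(pl.LightningModule):\n",
    "        def __init__(self, model):\n",
    "            super().__init__()\n",
    "            self.model = model\n",
    "\n",
    "        def training_step(self, batch, batch_idx):\n",
    "            xb, yb = batch\n",
    "            return loss_func(self.model(xb), yb)  # Lightning runs backward, step, and zero_grad for us\n",
    "\n",
    "        def validation_step(self, batch, batch_idx):\n",
    "            xb, yb = batch\n",
    "            self.log(\"validation/loss\", loss_func(self.model(xb), yb))\n",
    "\n",
    "        def configure_optimizers(self):\n",
    "            return configure_optimizer(self.model)\n",
    "\n",
    "    # Left commented out: Lightning manages device placement itself,\n",
    "    # which would overlap with our push_to_device transform.\n",
    "    # trainer = pl.Trainer(max_epochs=epochs)\n",
    "    # trainer.fit(LitModel(model), datamodule.train_dataloader(), datamodule.val_dataloader())"
    ]
    },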
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "-4piIilkyZpD"
    },
    "source": [
    "# Exercises"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "E482VfIlyZpD"
    },
    "source": [
    "### 🌟 Try out different hyperparameters for the `MLP` and for training."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "IQ8bkAxNyZpD"
    },
    "source": [
    "The `MLP` class is configured via the `args` argument to its constructor,\n",
    "which can set the values of hyperparameters like the width of layers and the degree of dropout:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "3Tl-AvMVyZpD"
    },
    "outputs": [],
    "source": [
    "MLP.__init__??"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "0HfbQ0KkyZpD"
    },
    "source": [
    "As the type signature indicates, `args` is an `argparse.Namespace`.\n",
    "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
    "and later on we'll see how to configure models\n",
    "and launch training jobs from the command line\n",
    "in the FSDL codebase.\n",
    "\n",
    "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
    "\n",
    "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
    "\n",
    "Can you get a final `valid`ation `acc`uracy of 98%?\n",
    "Can you get to 95% 2x faster than the baseline `MLP`?"
    ]
    },
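    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "For example, a `Namespace` can be built directly from a dictionary\n",
    "of hyperparameter names and values.\n",
    "The names below (`fc1`, `fc2`, `fc_dropout`) are illustrative guesses --\n",
    "check the output of `MLP.__init__??` above\n",
    "for the names your `MLP` actually accepts:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "from argparse import Namespace\n",
    "\n",
    "# hypothetical hyperparameter names -- confirm against MLP.__init__?? above\n",
    "example_args = Namespace(**{\"fc1\": 256, \"fc2\": 128, \"fc_dropout\": 0.2})\n",
    "example_args.fc1, vars(example_args)"
    ]
    },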
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "-vVtGJhtyZpD"
    },
    "outputs": [],
    "source": [
    "%%time \n",
    "from argparse import Namespace # you'll need this\n",
    "\n",
    "args = None # edit this\n",
    "\n",
    "epochs = 2 # used in fit\n",
    "bs = 32 # used by the DataModule\n",
    "\n",
    "\n",
    "# used in fit, play around with this if you'd like\n",
    "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
    " return optim.Adam(model.parameters(), lr=3e-4)\n",
    "\n",
    "\n",
    "model = MLP(data_config, args=args)\n",
    "model.to(device)\n",
    "\n",
    "datamodule = MNISTDataModule(dir=path, bs=bs)\n",
    "\n",
    "model.fit(datamodule=datamodule)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "7yyxc3uxyZpD"
    },
    "outputs": [],
    "source": [
    "val_dataloader = datamodule.val_dataloader()\n",
    "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
    "valid_acc"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "0ZHygZtgyZpE"
    },
    "source": [
    "### 🌟🌟🌟 Write your own `nn.Module`."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "r3Iu73j3yZpE"
    },
    "source": [
    "Designing new models is one of the most fun\n",
    "aspects of building an ML-powered application.\n",
    "\n",
    "Can you make an `nn.Module` that looks different from\n",
    "the standard `MLP` but still gets 98% validation accuracy or higher?\n",
    "You might start from the `MLP` and\n",
    "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
    "while adding more bells and whistles.\n",
    "Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
    "\n",
    "Here's some tricks you can try that are especially helpful with deeper networks:\n",
    "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
    "layers, which can improve\n",
    "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
    "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
    "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
    "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
    "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
    "\n",
    "If you want to make an `nn.Module` that can have different depths,\n",
    "check out the\n",
    "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
    ]
    },
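    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "As a starting point -- a sketch under arbitrary depth and width choices,\n",
    "not a full solution -- here's how `nn.Sequential` can stack a configurable\n",
    "number of `Linear`-`BatchNorm`-`ReLU` blocks.\n",
    "You could adapt something like it inside the template in the next cell:"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
    "def make_stack(dims):\n",
    "    \"\"\"Chains Linear -> BatchNorm1d -> ReLU blocks, skipping norm and activation after the final Linear.\"\"\"\n",
    "    layers = []\n",
    "    for d_in, d_out in zip(dims[:-1], dims[1:]):\n",
    "        layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU()]\n",
    "    return nn.Sequential(*layers[:-2])  # drop the trailing BatchNorm and ReLU\n",
    "\n",
    "\n",
    "make_stack([784, 256, 128, 10])  # 784 input pixels, 10 digit classes"
    ]
    },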
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "JsF_RfrDyZpE"
    },
    "outputs": [],
    "source": [
    "class YourModel(nn.Module):\n",
    " def __init__(self): # add args and kwargs here as you like\n",
    " super().__init__()\n",
    " # use those args and kwargs to set up the submodules\n",
    " self.ps = nn.Parameter(torch.zeros(10))\n",
    "\n",
    " def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
    " xb = torch.stack([self.ps for ii in range(len(xb))])\n",
    " return xb\n",
    " \n",
    " \n",
    "YourModel.fit = fit # don't forget this!"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "t6OQidtGyZpE"
    },
    "outputs": [],
    "source": [
    "model = YourModel()\n",
    "model.to(device)\n",
    "\n",
    "datamodule = MNISTDataModule(dir=path, bs=bs)\n",
    "\n",
    "model.fit(datamodule=datamodule)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "CH0U4ODoyZpE"
    },
    "outputs": [],
    "source": [
    "val_dataloader = datamodule.val_dataloader()\n",
    "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
    "valid_acc"
    ]
    }
    ],
    "metadata": {
    "accelerator": "GPU",
    "colab": {
    "name": "lab01_pytorch.ipynb",
    "private_outputs": true,
    "provenance": [],
    "toc_visible": true
    },
    "gpuClass": "standard",
    "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.7.13"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 0
    }