{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": "This notebook breaks down how the `cross_entropy` function (the functional form of `CrossEntropyLoss`, used for classification) is implemented in PyTorch, and how it relates to `softmax`, `log_softmax`, and `nll` (negative log-likelihood)." }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F", "execution_count": 82, "outputs": [] }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "batch_size, n_classes = 5, 3\nx = torch.randn(batch_size, n_classes)\nx.shape", "execution_count": 83, "outputs": [ { "data": { "text/plain": "torch.Size([5, 3])" }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "x", "execution_count": 84, "outputs": [ { "data": { "text/plain": "tensor([[ 0.9826, 1.0630, -0.4096],\n [-0.6213, 0.2511, 0.5659],\n [ 0.5662, 0.7360, -0.6783],\n [-0.4638, -1.4961, -1.0877],\n [ 1.8186, -0.2998, 0.1128]])" }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)\ntarget", "execution_count": 85, "outputs": [ { "data": { "text/plain": "tensor([1, 0, 1, 1, 1])" }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": {}, "cell_type": "markdown", "source": "### `softmax` + `nl` (negative likelihood)" }, { "metadata": {}, "cell_type": "markdown", "source": "This version is closest to the mathematical formula, but it is not numerically stable: `exp` overflows for large logits. Note that `nl` takes probabilities and applies the log itself, so it still computes a negative log-likelihood."
}, { "metadata": { "trusted": false }, "cell_type": "code", "source": "def softmax(x):\n    # Exponentiate, then normalize each row to sum to 1 (exp overflows for large logits).\n    return x.exp() / x.exp().sum(-1, keepdim=True)\n\ndef nl(input, target):\n    # Pick each row's probability for its target class, then average the negative logs.\n    return -input[range(target.shape[0]), target].log().mean()\n\npred = softmax(x)\nloss = nl(pred, target)\nloss", "execution_count": 86, "outputs": [ { "data": { "text/plain": "tensor(1.4904)" }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": {}, "cell_type": "markdown", "source": "### `log_softmax` + `nll` (negative log-likelihood)" }, { "metadata": {}, "cell_type": "markdown", "source": "From the PyTorch documentation (https://pytorch.org/docs/stable/nn.html?highlight=logsoftmax#torch-nn-functional):\n\n>While mathematically equivalent to `log(softmax(x))`, doing these two operations separately is slower, and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly." }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "def log_softmax(x):\n    # log(softmax(x)) = x - log(sum(exp(x))); this naive form can still overflow in exp.\n    return x - x.exp().sum(-1, keepdim=True).log()\n\ndef nll(input, target):\n    # Average the negated log-probability of each row's target class.\n    return -input[range(target.shape[0]), target].mean()\n\npred = log_softmax(x)\nloss = nll(pred, target)\nloss", "execution_count": 88, "outputs": [ { "data": { "text/plain": "tensor(1.4904)" }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": {}, "cell_type": "markdown", "source": "### `F.log_softmax` + `F.nll_loss`" }, { "metadata": {}, "cell_type": "markdown", "source": "The same computation as above, using PyTorch's built-in functions."
}, { "metadata": { "trusted": false }, "cell_type": "code", "source": "pred = F.log_softmax(x, dim=-1)\nloss = F.nll_loss(pred, target)\nloss", "execution_count": 89, "outputs": [ { "data": { "text/plain": "tensor(1.4904)" }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": {}, "cell_type": "markdown", "source": "### `F.cross_entropy`" }, { "metadata": {}, "cell_type": "markdown", "source": "PyTorch's `F.cross_entropy` combines `log_softmax` and `nll_loss` in a single call, taking the raw logits directly." }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "F.cross_entropy(x, target)", "execution_count": 90, "outputs": [ { "data": { "text/plain": "tensor(1.4904)" }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ] }, { "metadata": {}, "cell_type": "markdown", "source": "Reference:\n- https://github.com/fastai/fastai_old" }, { "metadata": { "trusted": false }, "cell_type": "code", "source": "", "execution_count": null, "outputs": [] } ], "metadata": { "_draft": { "nbviewer_url": "https://gist.github.com/217dcc6ae9171d7a46ce42e215c1fee0" }, "gist": { "id": "217dcc6ae9171d7a46ce42e215c1fee0", "data": { "description": "Cross entropy implementation in pytorch", "public": true } }, "kernelspec": { "name": "conda-env-fastaiv1-py", "display_name": "Python [conda env:fastaiv1]", "language": "python" }, "language_info": { "name": "python", "version": "3.7.0", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "base_numbering": 1, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }