{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s_jAcr48kgIj" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "# if torch.backends.mps.is_available(): device = torch.device(\"mps\") # special for Mac" ] }, { "cell_type": "markdown", "metadata": { "id": "jjO62A2ikgIh" }, "source": [ "### Multi-head Attention from Scratch\n", "\n", "This notebook implements from scratch, in a step-by-step fashion, a multi-head self-attention layer, which gives the same output as the Pytorch implementation.\n", "\n", "References:\n", "- Original transformer paper: [Attention is all you need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)\n", "- PyTorch implementation is in the function [`multi_head_attention_forward`](\n", "https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py)\n", "\n", "by [Zaccharie Ramzi](https://zaccharieramzi.fr/) and [Gabriel Peyré](http://www.gpeyre.com/)." ] }, { "cell_type": "markdown", "metadata": { "id": "fhVrcbzwkgIk" }, "source": [ "The data is composed of batches of $p$ points $(x_s^b)_{s=0}^{p-1}$ in $\\mathbb{R}^d$ stored in a matrix `X` of size $(n_b,p,d)$. Here $n_b$ is the number of batches, so that $b$ runs in $0 \\ldots n_b-1$. These $n_b$ batches are processed in parallel. Note that we use here the \"batch first\" format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "54FOnOPIkgIk" }, "outputs": [], "source": [ "n_b = 8 # size of batch, processed in parallel\n", "p = 80 # number of points in each points cloud\n", "d = 12 # dimension of the points\n", "X = torch.randn(n_b, p, d, device=device)" ] }, { "cell_type": "markdown", "metadata": { "id": "kPNMgSPzkgIl" }, "source": [ "Generate the parameter of the attention layer." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hxP1eBy1kgIl" }, "outputs": [], "source": [ "K = torch.randn(d, d, device=device)\n", "Q = torch.randn(d, d, device=device)\n", "V = torch.randn(d, d, device=device)\n", "L = torch.randn(d, d, device=device)" ] }, { "cell_type": "markdown", "metadata": { "id": "Tkp5_aamkgIl" }, "source": [ "$\\newcommand{\\coloneqq}{:=}$\n", "First the points are transformed in Keys, Queries, Values using matrices $K \\in \\mathbb{R}^{d \\times d}$, $Q \\in \\mathbb{R}^{d \\times d}$, $V \\in \\mathbb{R}^{d \\times d}$\n", "$$\n", " \\forall s = 0,\\ldots,p-1, \\quad\n", " k_s^b \\coloneqq K x_i^b, \\quad\n", " q_s^b \\coloneqq Q x_i^b, \\quad\n", " v_s^b \\coloneqq V x_i^b.\n", "$$\n", "These points are stored in the arrays `KX,QX,VX` of size $(n_b,p,d)$.\n", "\n", "We use Einstein summation notations to compute the transform, this is very useful and should be prefered over direct array manipulation (transposition, etc)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "zlY6K5kPkgIm" }, "outputs": [], "source": [ "KX = torch.einsum(\"ij,bsj->bsi\", [K, X])\n", "QX = torch.einsum(\"ij,bsj->bsi\", [Q, X])\n", "VX = torch.einsum(\"ij,bsj->bsi\", [V, X])" ] }, { "cell_type": "markdown", "metadata": { "id": "4POiS_7mkgIm" }, "source": [ "Then each of these points such as $k_s^b \\in \\mathbb{R}^{d}$ are split into $n_h$ (\"number of heads\") points $k_{s}^{b,h} \\in \\mathbb{R}^{d_h}$ where $d_h \\coloneqq d/n_h$, i.e.\n", "$$\n", " k_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}), \\quad\n", " q_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}), \\quad\n", " v_s^b = (k_{s}^{b,0},\\ldots,k_{s}^{b,n_h-1}).\n", "$$\n", "These new points are still stored in the same `KX,QX,VX`, but they have size $(n_b,p,n_h,d_h)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uWEbCNzZkgIn" }, "outputs": [], "source": [ "n_h = 2 # number of heads\n", "d_h = d // n_h # dimension of each head\n", "KX = KX.reshape(n_b, p, n_h, d_h )\n", "QX = QX.reshape(n_b, p, n_h, d_h )\n", "VX = VX.reshape(n_b, p, n_h, d_h )" ] }, { "cell_type": "markdown", "metadata": { "id": "Ew5nPSM3kgIn" }, "source": [ "We then compute, for each head $h=0,\\ldots,n_h-1$, the inner products between the keys and queries\n", "$$\n", " \\forall (s,t) \\in \\{0,\\ldots,p-1\\}^2, \\quad\n", " D_{s,t}^{b,h} \\coloneqq \\langle k_{s}^{b,h}, q_{t}^{b,h} \\rangle_{\\mathbb{R}^{d_h}}\n", "$$\n", "and they are stored in the matrix `D` of size $(n_b,n_h,p,p)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HCSTIpqUkgIn" }, "outputs": [], "source": [ "D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])" ] }, { "cell_type": "markdown", "metadata": { "id": "ukk_wbOMkgIn" }, "source": [ "From these, one compute the attention kernel $U$ and row-normalize it to obtain $\\tilde U$ stored in `Ut` of size $(n_b,n_h,p,p)$\n", "$$\n", " \\tilde U_{s,t}^{b,h} \\coloneqq \\frac{U_{s,t}^{b,h}}{\\sum_{t'} U_{s,t'}^{b,h}}\n", " \\quad\\text{where}\\quad\n", " U_{s,t}^{b,h} \\coloneqq e^{\\frac{D_{s,t}^{b,h}}{\\sqrt{d_h}}}.\n", "$$\n", "The $1/\\sqrt{d_h}$ scaling is such that, at initialization, if $(K,Q)$ are Gaussian white noise with unit variance, then the entries of $\\tilde U_{s,t}^h$ have roughly the same amplitude, which is important to ease training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "GmmQmGTIkgIo" }, "outputs": [], "source": [ "r = torch.sqrt(torch.tensor(d_h).double()) # note that this is the per-head dimension and not the full attention dimension\n", "U = torch.exp(D / r)\n", "Ut = U / torch.sum(U, axis=3, keepdim=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "J_iJB21YkgIo" }, "source": [ "This kernel is then used to barycenter the values points to obtains new points\n", "$$\n", " \\forall s = 0,\\ldots,p-1, \\quad\n", " z_{s}^{b,h} \\coloneqq \\sum_{t=0}^{p-1} \\tilde U_{s,t}^{b,h} v_t^b.\n", "$$\n", "These new points are stored in the array `Z` of size $(n_b,p,n_h,d_h)$." 
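{ "cell_type": "markdown", "metadata": {}, "source": [ "This contraction over $t$ is implemented below with `torch.einsum`. As an optional sanity check (this cell and its variable names are ours), the same contraction can be written as a batched matrix product after moving the head axis of `VX` forward." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# optional sanity check (ours): the einsum contraction over t equals a batched matrix product\n", "Z_einsum = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX]) # (n_b,p,n_h,d_h)\n", "Z_matmul = torch.matmul(Ut, VX.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) # (n_b,n_h,p,p) @ (n_b,n_h,p,d_h), back to (n_b,p,n_h,d_h)\n", "print(torch.allclose(Z_einsum, Z_matmul, atol=1e-5))" ] },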
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "q5tjoIWCkgIo" }, "outputs": [], "source": [ "Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX])" ] },
{ "cell_type": "markdown", "metadata": { "id": "fUpWQMKRkgIo" }, "source": [ "The outputs of all the heads are then grouped into new points\n", "$$\n", " \\forall s = 0,\\ldots,p-1, \\quad\n", " z_{s}^{b} \\coloneqq (z_{s}^{b,0},\\ldots,z_{s}^{b,n_h-1}) \\in \\mathbb{R}^d.\n", "$$\n", "They are still stored in the same array `Z`, now of size $(n_b,p,d)$." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "HFBH7YHakgIo" }, "outputs": [], "source": [ "Z = Z.reshape(n_b, p, n_h*d_h)" ] },
{ "cell_type": "markdown", "metadata": { "id": "NaCu3ZUCkgIo" }, "source": [ "Then a final linear matrix $L \\in \\mathbb{R}^{d \\times d}$ is applied independently to each point to obtain the output\n", "$$\n", " \\forall s = 0,\\ldots,p-1, \\quad\n", " y_{s}^{b} \\coloneqq L z_{s}^{b}.\n", "$$\n", "These points are stored in an array `Y` of the same size as `X`; this is the output of the attention layer." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2Ps99ZdakgIp" }, "outputs": [], "source": [ "Y = torch.einsum(\"ij,bsj->bsi\", [L, Z])" ] },
{ "cell_type": "markdown", "metadata": { "id": "-zm1NG5HkgIp" }, "source": [ "Let us now put all of this together in a single function." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "YJ3p0Ba6kgIp" }, "outputs": [], "source": [ "def multi_head_attention(X, K, Q, V, L, n_h):\n", "    n_b, p, d = X.shape\n", "    d_h = d // n_h # dimension of the features of each head\n", "    assert d_h * n_h == d, \"Embedding size needs to be divisible by the number of heads\"\n", "    # apply the matrices K,Q,V to X, and then spread the results over the different heads\n", "    KX = torch.einsum(\"ij,bsj->bsi\", [K, X]).reshape(n_b, p, n_h, d_h)\n", "    QX = torch.einsum(\"ij,bsj->bsi\", [Q, X]).reshape(n_b, p, n_h, d_h)\n", "    VX = torch.einsum(\"ij,bsj->bsi\", [V, X]).reshape(n_b, p, n_h, d_h)\n", "    # compute the inner products between queries and keys\n", "    D = torch.einsum(\"bshi,bthi->bhst\", [QX, KX])\n", "    # scaled exponential kernel\n", "    r = torch.sqrt(torch.tensor(d_h).double())\n", "    U = torch.exp(D / r)\n", "    # row normalize (softmax)\n", "    Ut = U / torch.sum(U, axis=3, keepdim=True)\n", "    # apply the kernel to the values and regroup the heads\n", "    Z = torch.einsum(\"bhst,bthi->bshi\", [Ut, VX]).reshape(n_b, p, n_h*d_h)\n", "    # apply the final linear layer\n", "    return torch.einsum(\"ij,bsj->bsi\", [L, Z])" ] },
{ "cell_type": "markdown", "metadata": { "id": "funvuL0XkgIp" }, "source": [ "Compare the PyTorch implementation with our own." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wIVhYALbkgIp", "outputId": "1c43dc70-d814-4179-d3de-61f52e873251" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.9810291e-07\n" ] } ], "source": [ "# using the PyTorch implementation\n", "M = torch.nn.MultiheadAttention(d, n_h, batch_first=True, dropout=0.0, bias=False, device=device) # make sure batch_first matches your data layout\n", "Y_torch, _ = M(X, X, X) # self attention\n", "\n", "# retrieve the Q, K, V matrices\n", "Q = M.in_proj_weight[:d, :]\n", "K = M.in_proj_weight[d:2*d, :]\n", "V = M.in_proj_weight[2*d:, :]\n", "# final projection matrix\n", "L = M.out_proj.weight\n", "\n", "# using our own code\n", "Y = multi_head_attention(X, K, Q, V, L, n_h)\n", "\n", "# relative error: should be ~0, up to floating point precision\n", "print((torch.norm(Y_torch - Y) / torch.norm(Y)).detach().cpu().numpy())" ] },
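{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, the same computation can be packaged as a trainable layer. The sketch below is ours (the class name `MultiHeadAttentionLayer` and the $1/\\sqrt{d}$ initialization are illustrative assumptions, not part of the original notebook): it simply wraps the matrices $K,Q,V,L$ as learnable parameters of an `nn.Module` and reuses the `multi_head_attention` function defined above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# optional sketch (ours): wrap the four matrices as learnable parameters of an nn.Module\n", "class MultiHeadAttentionLayer(nn.Module):\n", "    def __init__(self, d, n_h):\n", "        super().__init__()\n", "        assert d % n_h == 0, \"Embedding size needs to be divisible by the number of heads\"\n", "        self.n_h = n_h\n", "        # learnable projection matrices, same shapes as the K,Q,V,L used above\n", "        self.K = nn.Parameter(torch.randn(d, d) / d**0.5)\n", "        self.Q = nn.Parameter(torch.randn(d, d) / d**0.5)\n", "        self.V = nn.Parameter(torch.randn(d, d) / d**0.5)\n", "        self.L = nn.Parameter(torch.randn(d, d) / d**0.5)\n", "\n", "    def forward(self, X):\n", "        return multi_head_attention(X, self.K, self.Q, self.V, self.L, self.n_h)\n", "\n", "layer = MultiHeadAttentionLayer(d, n_h).to(device)\n", "print(layer(X).shape) # should be (n_b, p, d)" ] },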
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "afbNwNKUkgIq" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [], "include_colab_link": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "vscode": { "interpreter": { "hash": "1bd775ba363a980e8663c4b5e6fe16a4d2483fcdf14e1d1ee576e4bf99bce45c" } } }, "nbformat": 4, "nbformat_minor": 0 }