"Open

In [None]:
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if torch.backends.mps.is_available(): device = torch.device("mps") # special for Mac

### Multi-head Attention from Scratch

This notebook implements a multi-head self-attention layer from scratch, step by step, and checks that it gives the same output as the PyTorch implementation.

References:
- Original transformer paper: [Attention is all you need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- The PyTorch implementation is in the function [`multi_head_attention_forward`](https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py)

by [Zaccharie Ramzi](https://zaccharieramzi.fr/) and [Gabriel Peyré](http://www.gpeyre.com/).

The data consists of $n_b$ point clouds, each made of $p$ points $(x_s^b)_{s=0}^{p-1}$ in $\mathbb{R}^d$, stored in a tensor `X` of size $(n_b,p,d)$. Here $n_b$ is the batch size, so that $b$ ranges over $0, \ldots, n_b-1$. The $n_b$ point clouds are processed in parallel. Note that we use the "batch first" format.

In [None]:
n_b = 8 # size of batch, processed in parallel
p = 80 # number of points in each point cloud
d = 12 # dimension of the points
X = torch.randn(n_b, p, d, device=device)

Generate the parameters of the attention layer.

In [None]:
K = torch.randn(d, d, device=device)
Q = torch.randn(d, d, device=device)
V = torch.randn(d, d, device=device)
L = torch.randn(d, d, device=device)

$\newcommand{\coloneqq}{:=}$
First, the points are transformed into keys, queries, and values using matrices $K \in \mathbb{R}^{d \times d}$, $Q \in \mathbb{R}^{d \times d}$, $V \in \mathbb{R}^{d \times d}$
$$
 \forall s = 0,\ldots,p-1, \quad
 k_s^b \coloneqq K x_s^b, \quad
 q_s^b \coloneqq Q x_s^b, \quad
 v_s^b \coloneqq V x_s^b.
$$
These points are stored in the arrays `KX`, `QX`, `VX` of size $(n_b,p,d)$.

We use Einstein summation notation (`torch.einsum`) to compute the transform. It is very convenient and should be preferred over direct array manipulation (transposition, etc.).

In [None]:
KX = torch.einsum("ij,bsj->bsi", [K, X])
QX = torch.einsum("ij,bsj->bsi", [Q, X])
VX = torch.einsum("ij,bsj->bsi", [V, X])
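
To make the index convention concrete, here is a minimal sketch, on small illustrative sizes that are not the ones used in this notebook, checking that the einsum string `"ij,bsj->bsi"` applies the matrix to every point, i.e. that it agrees with the matrix product `X @ K.T`.

```python
import torch

torch.manual_seed(0)
n_b, p, d = 2, 5, 4  # small illustrative sizes (assumption, not the notebook's values)
X = torch.randn(n_b, p, d)
K = torch.randn(d, d)

# einsum form: KX[b, s, i] = sum_j K[i, j] * X[b, s, j]
KX_einsum = torch.einsum("ij,bsj->bsi", K, X)

# equivalent batched matrix product: each row x_s is mapped to K x_s
KX_matmul = X @ K.T

# explicit check on a single point
b, s = 1, 3
single_point = K @ X[b, s]

print(torch.allclose(KX_einsum, KX_matmul, atol=1e-5))
print(torch.allclose(KX_einsum[b, s], single_point, atol=1e-5))
```

The einsum form keeps the roles of the indices explicit (`b` batch, `s` point, `i`/`j` features), which is what makes the later steps easy to read.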

Then each of these points, such as $k_s^b \in \mathbb{R}^{d}$, is split into $n_h$ ("number of heads") points $k_{s}^{b,h} \in \mathbb{R}^{d_h}$ where $d_h \coloneqq d/n_h$, i.e.
$$
 k_s^b = (k_{s}^{b,0},\ldots,k_{s}^{b,n_h-1}), \quad
 q_s^b = (q_{s}^{b,0},\ldots,q_{s}^{b,n_h-1}), \quad
 v_s^b = (v_{s}^{b,0},\ldots,v_{s}^{b,n_h-1}).
$$
These new points are still stored in the same arrays `KX`, `QX`, `VX`, which now have size $(n_b,p,n_h,d_h)$.

In [None]:
n_h = 2 # number of heads
d_h = d // n_h # dimension of each head
KX = KX.reshape(n_b, p, n_h, d_h )
QX = QX.reshape(n_b, p, n_h, d_h )
VX = VX.reshape(n_b, p, n_h, d_h )
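
As a small sanity check of what this reshape does, the sketch below (toy sizes chosen for readability, not the notebook's values) verifies that head $h$ simply receives the contiguous block of coordinates $h d_h, \ldots, (h+1) d_h - 1$ of each point.

```python
import torch

n_b, p, d, n_h = 2, 3, 6, 2  # toy sizes (assumption)
d_h = d // n_h

# fill with consecutive integers so the layout is easy to read
KX = torch.arange(n_b * p * d, dtype=torch.float32).reshape(n_b, p, d)
KX_heads = KX.reshape(n_b, p, n_h, d_h)

# head h holds coordinates h*d_h, ..., (h+1)*d_h - 1 of every point
heads_match = all(
    torch.equal(KX_heads[:, :, h, :], KX[:, :, h * d_h:(h + 1) * d_h])
    for h in range(n_h)
)
print(heads_match)

# first point of the first batch: coordinates 0..5,
# so head 0 is [0, 1, 2] and head 1 is [3, 4, 5]
print(KX[0, 0])
print(KX_heads[0, 0])
```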

We then compute, for each head $h=0,\ldots,n_h-1$, the inner products between the queries and keys
$$
 \forall (s,t) \in \{0,\ldots,p-1\}^2, \quad
 D_{s,t}^{b,h} \coloneqq \langle q_{s}^{b,h}, k_{t}^{b,h} \rangle_{\mathbb{R}^{d_h}}
$$
and they are stored in the array `D` of size $(n_b,n_h,p,p)$. The first point index $s$ is the query index and the second one $t$ is the key index.

In [None]:
D = torch.einsum("bshi,bthi->bhst", [QX, KX])
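
For comparison, the same quantity can be written with explicit axis permutations and a batched matrix product. The following sketch uses toy sizes and random stand-ins for `QX` and `KX` (assumptions for illustration only), and also spot-checks one entry against a plain dot product.

```python
import torch

torch.manual_seed(0)
n_b, p, n_h, d_h = 2, 5, 2, 3  # toy sizes (assumption)
QX = torch.randn(n_b, p, n_h, d_h)
KX = torch.randn(n_b, p, n_h, d_h)

# einsum form: D[b, h, s, t] = <QX[b, s, h, :], KX[b, t, h, :]>
D = torch.einsum("bshi,bthi->bhst", QX, KX)

# same result with explicit permutations:
# (n_b, n_h, p, d_h) @ (n_b, n_h, d_h, p) -> (n_b, n_h, p, p)
Qh = QX.permute(0, 2, 1, 3)
Kh = KX.permute(0, 2, 3, 1)
D_matmul = Qh @ Kh

# spot check of a single entry
b, h, s, t = 1, 0, 3, 2
entry = torch.dot(QX[b, s, h], KX[b, t, h])

print(torch.allclose(D, D_matmul, atol=1e-5))
print(torch.allclose(D[b, h, s, t], entry, atol=1e-5))
```

The einsum version avoids having to keep track of which axes were moved where, which is the main reason it is used throughout this notebook.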

From these, one computes the attention kernel $U$ and row-normalizes it to obtain $\tilde U$, stored in `Ut` of size $(n_b,n_h,p,p)$
$$
 \tilde U_{s,t}^{b,h} \coloneqq \frac{U_{s,t}^{b,h}}{\sum_{t'} U_{s,t'}^{b,h}}
 \quad\text{where}\quad
 U_{s,t}^{b,h} \coloneqq e^{\frac{D_{s,t}^{b,h}}{\sqrt{d_h}}}.
$$
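The $1/\sqrt{d_h}$ scaling is such that, at initialization, if $(K,Q)$ are Gaussian white noise with unit variance, then the entries of $\tilde U_{s,t}^{b,h}$ have roughly the same amplitude, which is important to ease training. Without it, the spread of the inner products $D_{s,t}^{b,h}$ grows with $d_h$, and the normalized kernel tends to concentrate on a few entries.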
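
The effect of the scaling can be illustrated numerically. The sketch below makes a simplifying assumption that is not exactly the setting of this notebook: the query and key vectors themselves have i.i.d. standard Gaussian entries. Under this assumption the inner product has variance $d_h$, and dividing by $\sqrt{d_h}$ brings it back to about 1 whatever the dimension.

```python
import torch

torch.manual_seed(0)
n_samples = 20000  # number of random (query, key) pairs per dimension

variances = {}
for d_h in (4, 64, 256):
    # assumption: entries of the queries and keys are i.i.d. standard Gaussian
    q = torch.randn(n_samples, d_h)
    k = torch.randn(n_samples, d_h)
    dots = (q * k).sum(dim=1)  # one inner product per pair
    raw_var = dots.var().item()
    scaled_var = (dots / d_h ** 0.5).var().item()
    variances[d_h] = (raw_var, scaled_var)
    print(d_h, raw_var, scaled_var)
```

The first printed variance grows roughly like $d_h$, while the second stays of order 1.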

In [None]:
r = torch.sqrt(torch.tensor(d_h).double()) # note that this is the per-head dimension and not the full attention dimension
U = torch.exp(D / r)
Ut = U / torch.sum(U, axis=3, keepdim=True)
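
The exponential followed by a row normalization is exactly a softmax along the key index. As a hedged side note, the sketch below (toy sizes and random scores, chosen for illustration) checks this against `torch.softmax`, and also shows why the library routine is usually preferred in practice: for large scores the explicit exponential overflows in float32, while `torch.softmax` stays finite.

```python
import torch

torch.manual_seed(0)
n_b, n_h, p, d_h = 2, 2, 5, 6  # toy sizes (assumption)
D = torch.randn(n_b, n_h, p, p)
r = d_h ** 0.5

# explicit kernel and row normalization, as in the cell above
U = torch.exp(D / r)
Ut = U / U.sum(dim=3, keepdim=True)

# the same thing in one call
Ut_softmax = torch.softmax(D / r, dim=3)
print(torch.allclose(Ut, Ut_softmax, atol=1e-6))

# with a very large score, exp overflows to inf and inf / inf gives NaN
D_big = D.clone()
D_big[..., 0] = 1000.0
U_big = torch.exp(D_big / r)
naive = U_big / U_big.sum(dim=3, keepdim=True)
stable = torch.softmax(D_big / r, dim=3)
print(torch.isnan(naive).any().item(), torch.isfinite(stable).all().item())
```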

This kernel is then used to compute weighted averages (barycenters) of the value points, which gives new points
$$
 \forall s = 0,\ldots,p-1, \quad
 z_{s}^{b,h} \coloneqq \sum_{t=0}^{p-1} \tilde U_{s,t}^{b,h} v_t^{b,h}.
$$
These new points are stored in the array `Z` of size $(n_b,p,n_h,d_h)$.

In [None]:
Z = torch.einsum("bhst,bthi->bshi", [Ut, VX])
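
Since each row of $\tilde U$ is non-negative and sums to 1, every output point is a convex combination of the value points of the same head. The sketch below (toy sizes, random stand-ins for `Ut` and `VX`, all assumptions for illustration) checks the einsum against a batched matrix product and verifies that each output coordinate stays between the smallest and largest corresponding value coordinate.

```python
import torch

torch.manual_seed(0)
n_b, p, n_h, d_h = 2, 5, 2, 3  # toy sizes (assumption)
Ut = torch.softmax(torch.randn(n_b, n_h, p, p), dim=3)  # rows sum to 1
VX = torch.randn(n_b, p, n_h, d_h)

# einsum form: Z[b, s, h, :] = sum_t Ut[b, h, s, t] * VX[b, t, h, :]
Z = torch.einsum("bhst,bthi->bshi", Ut, VX)

# same with a batched matrix product and explicit permutations
Z_matmul = (Ut @ VX.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
print(torch.allclose(Z, Z_matmul, atol=1e-5))

# convex combination: outputs lie within the range of the values, per head and coordinate
lo = VX.min(dim=1, keepdim=True).values
hi = VX.max(dim=1, keepdim=True).values
inside = ((Z >= lo - 1e-6) & (Z <= hi + 1e-6)).all().item()
print(inside)
```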

The outputs of all the heads are then concatenated into new points
$$
 \forall s = 0,\ldots,p-1, \quad
 z_{s}^{b} \coloneqq (z_{s}^{b,0},\ldots,z_{s}^{b,n_h-1}) \in \mathbb{R}^d.
$$
They are still stored in the same array `Z`, which now has size $(n_b,p,d)$.

In [None]:
Z = Z.reshape(n_b, p, n_h*d_h)
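
This reshape concatenates the heads correctly only because the head axis sits right before the feature axis. The sketch below (toy sizes and random data, assumptions for illustration) checks it against an explicit `torch.cat`, and shows that with a "heads first" layout $(n_b,n_h,p,d_h)$, which other implementations sometimes use, the axes must be permuted back before reshaping.

```python
import torch

torch.manual_seed(0)
n_b, p, n_h, d_h = 2, 4, 2, 3  # toy sizes (assumption)
Z = torch.randn(n_b, p, n_h, d_h)

# layout used in this notebook: reshape concatenates the heads of each point
Z_merged = Z.reshape(n_b, p, n_h * d_h)
Z_cat = torch.cat([Z[:, :, h, :] for h in range(n_h)], dim=-1)
print(torch.equal(Z_merged, Z_cat))

# hypothetical heads-first layout (n_b, n_h, p, d_h)
Z_heads_first = Z.permute(0, 2, 1, 3)
wrong = Z_heads_first.reshape(n_b, p, n_h * d_h)  # mixes different points together
right = Z_heads_first.permute(0, 2, 1, 3).reshape(n_b, p, n_h * d_h)
print(torch.equal(wrong, Z_merged), torch.equal(right, Z_merged))
```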

Then a final linear map $L \in \mathbb{R}^{d \times d}$ is applied independently to each point to obtain the output
$$
 \forall s = 0,\ldots,p-1, \quad
 y_{s}^{b} \coloneqq L z_{s}^{b}.
$$
These points are output by the function in an array `Y` of the same size as `X`.

In [None]:
Y = torch.einsum("ij,bsj->bsi", [L, Z])

We now put all these steps together in a single function.

In [None]:
def multi_head_attention(X, K, Q, V, L, n_h):
 n_b, p, d = X.shape
 d_h = d // n_h # dimension of the features of each head
 assert d_h * n_h == d, "Embedding size needs to be divisible by the number of heads"
 # apply the matrices K,Q,V to X, and then spread them in the different heads
 KX = torch.einsum("ij,bsj->bsi", [K, X]).reshape( n_b, p, n_h, d_h )
 QX = torch.einsum("ij,bsj->bsi", [Q, X]).reshape( n_b, p, n_h, d_h )
 VX = torch.einsum("ij,bsj->bsi", [V, X]).reshape( n_b, p, n_h, d_h )
 # compute the inner products between queries and keys
 D = torch.einsum("bshi,bthi->bhst", [QX, KX])
 # scaled kernel
 r = torch.sqrt(torch.tensor(d_h).double())
 U = torch.exp(D / r)
 # row normalize (softmax)
 Ut = U / torch.sum(U, axis=3)[:,:,:,None]
 # apply kernel
 Z = torch.einsum("bhst,bthi->bshi", [Ut, VX]).reshape(n_b, p, n_h*d_h)
 # apply final linear layer
 return torch.einsum("ij,bsj->bsi", [L, Z])
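
A useful property to test on such a function is permutation equivariance: since no positional information is used, reordering the input points should simply reorder the output points in the same way. The sketch below restates the function in a compact form (using `torch.softmax` in place of the explicit exponential, which is a choice made here for brevity) and checks this property on random data. The sizes and the $1/\sqrt{d}$ scaling of the random weights are assumptions made for the illustration.

```python
import torch

def multi_head_attention(X, K, Q, V, L, n_h):
    """Compact restatement of the function above (softmax replaces exp + normalization)."""
    n_b, p, d = X.shape
    d_h = d // n_h
    assert d_h * n_h == d, "Embedding size needs to be divisible by the number of heads"
    KX = torch.einsum("ij,bsj->bsi", K, X).reshape(n_b, p, n_h, d_h)
    QX = torch.einsum("ij,bsj->bsi", Q, X).reshape(n_b, p, n_h, d_h)
    VX = torch.einsum("ij,bsj->bsi", V, X).reshape(n_b, p, n_h, d_h)
    D = torch.einsum("bshi,bthi->bhst", QX, KX)
    Ut = torch.softmax(D / d_h ** 0.5, dim=3)
    Z = torch.einsum("bhst,bthi->bshi", Ut, VX).reshape(n_b, p, d)
    return torch.einsum("ij,bsj->bsi", L, Z)

torch.manual_seed(0)
n_b, p, d, n_h = 2, 7, 8, 4  # toy sizes (assumption)
X = torch.randn(n_b, p, d)
K, Q, V, L = (torch.randn(d, d) / d ** 0.5 for _ in range(4))

Y = multi_head_attention(X, K, Q, V, L, n_h)

# permute the points of every cloud and run the layer again
perm = torch.randperm(p)
Y_perm = multi_head_attention(X[:, perm], K, Q, V, L, n_h)

# the output is permuted in the same way, up to round-off
print(Y.shape, torch.allclose(Y_perm, Y[:, perm], atol=1e-4))
```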

Compare the PyTorch implementation with our own.

In [None]:
# using pytorch code
M = torch.nn.MultiheadAttention(d, n_h, batch_first=True, dropout=0.0, bias=False, device=device) # make sure to use the batch_first arg. according to your data layout
Y_torch,_ = M(X, X, X) # self attention

# Retrieve the Q, K, V matrices
Q = M.in_proj_weight[:d, :]
K = M.in_proj_weight[d:2*d, :]
V = M.in_proj_weight[2*d:, :]
# final projection matrix
L = M.out_proj.weight

# using our own code
Y = multi_head_attention(X, K, Q, V, L, n_h)

# relative error, should be close to 0 (up to float32 round-off)
print((torch.norm(Y_torch - Y) / torch.norm(Y)).detach().cpu().numpy())

1.9810291e-07
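
Beyond the final output, the intermediate attention kernel can also be compared. The sketch below is a complementary check on toy sizes, under two assumptions: that the installed PyTorch version supports the `average_attn_weights` argument of `MultiheadAttention.forward`, and that `in_proj_weight` stacks the three projections row-wise in the order $Q$, $K$, $V$ (as used in the cell above). It recomputes $\tilde U$ from the extracted weights and compares it with the per-head weights returned by PyTorch.

```python
import torch

torch.manual_seed(0)
n_b, p, d, n_h = 2, 5, 8, 2  # toy sizes (assumption)
d_h = d // n_h
X = torch.randn(n_b, p, d)

M = torch.nn.MultiheadAttention(d, n_h, batch_first=True, dropout=0.0, bias=False)

with torch.no_grad():
    # per-head attention weights computed by PyTorch, shape (n_b, n_h, p, p)
    _, W_torch = M(X, X, X, need_weights=True, average_attn_weights=False)

    # split the stacked projection matrix into its Q, K, V blocks
    Q, K, V = M.in_proj_weight.chunk(3, dim=0)

    # recompute the row-normalized kernel with the steps of this notebook
    QX = (X @ Q.T).reshape(n_b, p, n_h, d_h)
    KX = (X @ K.T).reshape(n_b, p, n_h, d_h)
    D = torch.einsum("bshi,bthi->bhst", QX, KX)
    Ut = torch.softmax(D / d_h ** 0.5, dim=3)

# raises an AssertionError if the two kernels differ beyond the tolerances
torch.testing.assert_close(Ut, W_torch, rtol=1e-4, atol=1e-5)
print(W_torch.shape)
```

Using `torch.testing.assert_close` instead of printing a norm makes the comparison fail loudly if the two implementations drift apart.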
