## JAX vmap

This is the source material for a tweet thread I did recently: https://twitter.com/jakevdp/status/1612544608646606849

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/jakevdp/467da4f567d34c59c1f34559790ef85f)

---
Let's talk about JAX's vmap! It's a transformation that can automatically create vectorized, batched versions of your functions... but what exactly it does is sometimes misunderstood. So let's dig in!

<img src="https://jax.readthedocs.io/en/latest/_static/jax_logo_250px.png"/>
<font size=6>

```python
from jax import vmap
```

</font>


---
Suppose you've implemented a model that maps a vector input to a scalar output. As an example, here's a simple function similar to a single neuron in a neural net:

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(8675309)  # PRNGenny
W = rng.randn(3, 5) # weights
b = 1.0             # bias

def model(v, W=W, b=b):
  return jnp.tanh(W @ v + b).sum()

---
This function accepts a single length-5 vector of inputs, and outputs a scalar:

In [3]:
v = rng.randn(5)
print(model(v))

2.0699806


---
Now, suppose you want to apply this model across a 2D array, where each row of the array is an input. Passing this batched data directly leads to an error:

In [12]:
# This tells Jupyter to print one-line summaries of exceptions.
%xmode minimal

Exception reporting mode: Minimal


In [5]:
v_batch = rng.randn(4, 5)  # a batch of 4 input vectors
model(v_batch)

ValueError: ignored
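
To see where the shapes go wrong, here is a minimal sketch that rebuilds the same setup and catches the failure. The `error` variable is introduced only for illustration, and the exact exception type and message may differ between NumPy and JAX versions, so the sketch catches both `TypeError` and `ValueError`.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(8675309)
W = rng.randn(3, 5)        # weights, shape (3, 5)
b = 1.0                    # bias
v_batch = rng.randn(4, 5)  # four input vectors, shape (4, 5)

def model(v, W=W, b=b):
  return jnp.tanh(W @ v + b).sum()

# W @ v_batch tries to contract W's last axis (length 5) with
# v_batch's first axis (length 4), so the shapes do not line up.
print(W.shape, v_batch.shape)

error = None  # hypothetical helper variable, just to record the failure
try:
  model(v_batch)
except (TypeError, ValueError) as e:  # exact exception type may vary
  error = e
print(type(error).__name__)
```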

---
This error arises because our function is not defined in a way that can handle batched input. So what do we do? The easiest approach might be to use a simple Python list comprehension:

In [6]:
jnp.array([model(v) for v in v_batch])

DeviceArray([-2.263083 , -1.4514356,  0.9401485,  2.9187164], dtype=float32)

---
This works, of course, but if you're familiar with NumPy-style computing in Python you'll immediately recognize the problem: loops in Python are typically slow compared to the native vectorized operations offered by NumPy & JAX.

---
In the old days, you'd have to rewrite your model to explicitly accept batched data. This sometimes takes some thought; here, for example, the simple matrix product becomes an Einstein summation:

In [7]:
def batched_model(v_batch, W=W, b=b):
  # Here are the dimensions for the batched matrix product:
  #  W:       (m, k)
  #  v_batch: (n_batches, k)
  #  output:  (n_batches, m)
  return jnp.tanh(jnp.einsum("mk,nk->nm", W, v_batch) + b).sum(1)

# Results should match!
print(jnp.array([model(v) for v in v_batch]))
print(batched_model(v_batch))

[-2.263083  -1.4514356  0.9401485  2.9187164]
[-2.263083  -1.4514352  0.9401484  2.9187164]
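
For this particular model, the same contraction can also be written as a plain matrix product against the transposed weights. The sketch below is one way to check that. The name `batched_model_matmul` is made up for this example, and the tolerance is an assumption chosen to absorb float32 rounding.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(8675309)
W = rng.randn(3, 5)        # weights, shape (m, k)
b = 1.0                    # bias
v_batch = rng.randn(4, 5)  # inputs, shape (n, k)

def batched_model(v_batch, W=W, b=b):
  return jnp.tanh(jnp.einsum("mk,nk->nm", W, v_batch) + b).sum(1)

def batched_model_matmul(v_batch, W=W, b=b):
  # (n, k) @ (k, m) -> (n, m): the same contraction as "mk,nk->nm"
  return jnp.tanh(jnp.matmul(v_batch, W.T) + b).sum(1)

out_einsum = batched_model(v_batch)
out_matmul = batched_model_matmul(v_batch)
print(jnp.allclose(out_einsum, out_matmul, atol=1e-5))
```

Either way, you have to work out by hand which axis is the batch axis and which axis gets summed.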


---
As models get more complex, this sort of manual batchification can be complicated and error-prone. This is where jax.vmap comes in: it can transform your function into an efficient and correct batched version automatically!

In [8]:
from jax import vmap

print(batched_model(v_batch))  # manual batching
print(vmap(model)(v_batch))    # automatic batching!

[-2.263083  -1.4514352  0.9401484  2.9187164]
[-2.263083  -1.4514351  0.9401484  2.9187164]
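
By default, vmap maps over the leading axis of each positional argument. The `in_axes` parameter of `jax.vmap` lets you choose a different axis, or mark an argument as not batched with `None`. Here is a small sketch of both under the same setup; the variable names `out_rows`, `out_cols`, and `out_explicit` are only for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

rng = np.random.RandomState(8675309)
W = rng.randn(3, 5)        # weights
b = 1.0                    # bias
v_batch = rng.randn(4, 5)  # four input vectors, one per row

def model(v, W=W, b=b):
  return jnp.tanh(W @ v + b).sum()

# Default: map over axis 0 of the single positional argument.
out_rows = vmap(model)(v_batch)

# Same data stored column-wise, shape (5, 4): map over axis 1 instead.
out_cols = vmap(model, in_axes=1)(v_batch.T)

# Pass W and b explicitly and map only over v (None means "not batched").
out_explicit = vmap(model, in_axes=(0, None, None))(v_batch, W, b)

print(jnp.allclose(out_rows, out_cols, atol=1e-5))
print(jnp.allclose(out_rows, out_explicit, atol=1e-5))
```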


---
You might now ask which approach is more efficient: surely vmap must come at a cost? In most cases, however, vmap produces virtually the same operations as the manual implementation, which we can see by printing the jaxpr (JAX's internal function representation) for each.

In [10]:
jax.make_jaxpr(batched_model)(v_batch)

{ lambda a:f32[3,5]; b:f32[4,5]. let
    c:f32[4,3] = xla_call[
      call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let
          f:f32[4,3] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] e d
        in (f,) }
      name=_einsum
    ] a b
    g:f32[4,3] = add c 1.0
    h:f32[4,3] = tanh g
    i:f32[4] = reduce_sum[axes=(1,)] h
  in (i,) }

In [11]:
jax.make_jaxpr(vmap(model))(v_batch)

{ lambda a:f32[3,5]; b:f32[4,5]. let
    c:f32[3,4] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[3,4] = add c 1.0
    e:f32[3,4] = tanh d
    f:f32[4] = reduce_sum[axes=(0,)] e
  in (f,) }

---
The details differ slightly (for example, xla_call comes from the fact that einsum is jit-compiled), but the essential steps in the computation match more or less exactly: dot_general(), then add(), then tanh(), then reduce_sum().

<pre>
{ lambda a:f32[3,5]; b:f32[4,5]. let                      { lambda a:f32[3,5]; b:f32[4,5]. let
    c:f32[4,3] = xla_call[                                    c:f32[3,4] = <mark>dot_general</mark>[
      call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let          dimension_numbers=(((1,), (1,)), ((), ()))
          f:f32[4,3] = <mark>dot_general</mark>[                             precision=None
            dimension_numbers=(((1,), (1,)), ((), ()))          preferred_element_type=None
            precision=None                                    ] a b
            preferred_element_type=None                       d:f32[3,4] = <mark>add</mark> c 1.0
          ] e d                                               e:f32[3,4] = <mark>tanh</mark> d
        in (f,) }                                             f:f32[4] = <mark>reduce_sum</mark>[axes=(0,)] e
      name=_einsum                                          in (f,) }
    ] a b
    g:f32[4,3] = <mark>add</mark> c 1.0
    h:f32[4,3] = <mark>tanh</mark> g
    i:f32[4] = <mark>reduce_sum</mark>[axes=(1,)] h
  in (i,) }
</pre>
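
If you'd rather not compare the printed jaxprs by eye, you can walk them programmatically. The sketch below uses a hypothetical helper, `primitive_names`, that collects primitive names and descends into nested jaxprs such as the one wrapping einsum. It assumes equations expose `primitive.name` and `params`, which may change between JAX versions, and the name of the wrapping primitive may not be `xla_call` in your version.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap

rng = np.random.RandomState(8675309)
W = rng.randn(3, 5)
b = 1.0
v_batch = rng.randn(4, 5)

def model(v, W=W, b=b):
  return jnp.tanh(W @ v + b).sum()

def batched_model(v_batch, W=W, b=b):
  return jnp.tanh(jnp.einsum("mk,nk->nm", W, v_batch) + b).sum(1)

def primitive_names(jaxpr):
  """Collect primitive names, descending into nested jaxprs."""
  jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # unwrap a ClosedJaxpr if needed
  names = []
  for eqn in jaxpr.eqns:
    names.append(eqn.primitive.name)
    for param in eqn.params.values():
      inner = getattr(param, "jaxpr", param)
      if hasattr(inner, "eqns"):  # this parameter holds a nested jaxpr
        names.extend(primitive_names(inner))
  return names

manual = primitive_names(jax.make_jaxpr(batched_model)(v_batch))
auto = primitive_names(jax.make_jaxpr(vmap(model))(v_batch))
print(manual)
print(auto)
```

Both lists should contain the four core steps, possibly alongside wrapper or conversion primitives that depend on your JAX version.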

---
And this is what jax.vmap gives you: a way to automatically create efficient batched versions of your functions, which lower to fast vectorized computations, without having to rewrite your code by hand.
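
Because vmap is a function transformation, it composes with JAX's other transformations. As a minimal sketch, here it is wrapped in `jax.jit` so the batched function is compiled as a whole. The name `fast_batched_model` is made up for this example, and the tolerance is an assumption to absorb float32 rounding.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.RandomState(8675309)
W = rng.randn(3, 5)
b = 1.0
v_batch = rng.randn(4, 5)

def model(v, W=W, b=b):
  return jnp.tanh(W @ v + b).sum()

# Batch the per-example model, then compile the batched version.
fast_batched_model = jax.jit(jax.vmap(model))

out = fast_batched_model(v_batch)
expected = jnp.array([model(v) for v in v_batch])  # Python-loop reference
print(jnp.allclose(out, expected, atol=1e-4))
```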

You can read more about vmap and related transforms in the JAX docs: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html