{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "## JAX vmap\n", "\n", "This is the source material for a tweet thread I did recently: https://twitter.com/jakevdp/status/1612544608646606849\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/jakevdp/467da4f567d34c59c1f34559790ef85f)" ], "metadata": { "id": "4sEj41C8MWRj" } }, { "cell_type": "markdown", "source": [ "---\n", "Let's talk about JAX's vmap! It's a transformation that can automatically create vectorized, batched versions of your functions... but what exactly it does is sometimes misunderstood. So let's dig-in!\n", "\n", "\n", "\n", "\n", "```python\n", "from jax import vmap\n", "```\n", "\n", "\n" ], "metadata": { "id": "C-rrr7PTnWf3" } }, { "cell_type": "markdown", "source": [ "---\n", "Suppose you've implemented a model that maps a vector input to a scalar output. As an example, here's a simple function similar to a single neuron in a neural net:" ], "metadata": { "id": "HSUIp2U_c-na" } }, { "cell_type": "code", "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "rng = np.random.RandomState(8675309) # PRNGenny\n", "W = rng.randn(3, 5) # weights\n", "b = 1.0 # bias\n", "\n", "def model(v, W=W, b=b):\n", " return jnp.tanh(W @ v + b).sum()" ], "metadata": { "id": "v_3lI5DxrCWL" }, "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "source": [ "---\n", "This function accepts a single length-5 vector of inputs, and outputs a scalar:" ], "metadata": { "id": "-qoIeQhztUVS" } }, { "cell_type": "code", "source": [ "v = rng.randn(5)\n", "print(model(v))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ydrO9URuuN3O", "outputId": "124cef9d-4b3e-4d64-ec14-69bd73f491fd" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "2.0699806\n" ] } ] }, { "cell_type": "markdown", "source": [ "---\n", "Now, suppose you want to apply this model across a 2D array, where each row of the array is an input. Passing this batched data directly leads to an error:" ], "metadata": { "id": "obtKROhinUnQ" } }, { "cell_type": "code", "source": [ "# This tells Jupyter to print one-line summaries of exceptions.\n", "%xmode minimal" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RdETKwxQzmRa", "outputId": "be02c84c-a159-442e-d1bb-362f51e03825" }, "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Exception reporting mode: Minimal\n" ] } ] }, { "cell_type": "code", "source": [ "v_batch = rng.randn(4, 5) # 4 batches\n", "model(v_batch)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 97 }, "id": "6DUvyosYuw9i", "outputId": "029301bb-da66-4f1e-d839-828a6abea55b" }, "execution_count": 5, "outputs": [ { "output_type": "error", "ename": "ValueError", "evalue": "ignored", "traceback": [ "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 5)\n" ] } ] }, { "cell_type": "markdown", "source": [ "---\n", "This error arises because our function is not defined in a way that can handle batched input. So what do we do? The easiest approach might be to use a simple Python list comprehension:" ], "metadata": { "id": "FDa9OAzMvAH7" } }, { "cell_type": "code", "source": [ "jnp.array([model(v) for v in v_batch])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NFWmiu3EvOSX", "outputId": "e55cb1c7-fa44-4042-d2c1-8f7855e403d6" }, "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DeviceArray([-2.263083 , -1.4514356, 0.9401485, 2.9187164], dtype=float32)" ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "markdown", "source": [ "---\n", "This works, of course, but if you're familiar with NumPy-style computing in Python you'll immediately recognize the problem: loops in Python are typically slow compared to the native vectorized operations offered by NumPy & JAX." ], "metadata": { "id": "GDc2O5BOu67p" } }, { "cell_type": "markdown", "source": [ "---\n", "In the old days, you'd have to re-write your model to explicitly accept batched data. This sometimes takes some thought, for example here the simple matrix product becomes an Einstein summation:" ], "metadata": { "id": "d8fy5VFSv8Vs" } }, { "cell_type": "code", "source": [ "def batched_model(v_batch, W=W, b=b):\n", " # Here are the dimensions for the batched matrix product:\n", " # W: (m, k)\n", " # v_batch: (n_batches, k)\n", " # output: (n_batches, m)\n", " return jnp.tanh(jnp.einsum(\"mk,nk->nm\", W, v_batch) + b).sum(1)\n", "\n", "# Results should match!\n", "print(jnp.array([model(v) for v in v_batch]))\n", "print(batched_model(v_batch))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GHZxQ4DlwIHE", "outputId": "e4dd3798-17da-4215-e347-58e938d004cd" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[-2.263083 -1.4514356 0.9401485 2.9187164]\n", "[-2.263083 -1.4514352 0.9401484 2.9187164]\n" ] } ] }, { "cell_type": "markdown", "source": [ "---\n", "As models get more complex, this sort of manual batchification can be complicated and error-prone. This is where jax.vmap comes in: it can transform your function into an efficient and correct batched version automatically!" ], "metadata": { "id": "aepo4NQHwt3H" } }, { "cell_type": "code", "source": [ "from jax import vmap\n", "\n", "print(batched_model(v_batch)) # manual batching\n", "print(vmap(model)(v_batch)) # automatic batching!" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZrFT2m7DxxEK", "outputId": "df9a3130-7f35-46ce-be91-524857889481" }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[-2.263083 -1.4514352 0.9401484 2.9187164]\n", "[-2.263083 -1.4514351 0.9401484 2.9187164]\n" ] } ] }, { "cell_type": "markdown", "source": [ "---\n", "You might ask now which approach is more efficient: surely vmap must come at a cost? In most cases, however, vmap will produce virtually identical operations as the manual implementation, which we can see by printing the jaxpr (JAX's internal function representation) for each." ], "metadata": { "id": "AzHxQrUkyAFV" } }, { "cell_type": "code", "source": [ "jax.make_jaxpr(batched_model)(v_batch)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uyhwf0NOzu2O", "outputId": "a772259f-ddbb-4391-93e8-12b72e27ba9d" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{ lambda a:f32[3,5]; b:f32[4,5]. let\n", " c:f32[4,3] = xla_call[\n", " call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let\n", " f:f32[4,3] = dot_general[\n", " dimension_numbers=(((1,), (1,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None\n", " ] e d\n", " in (f,) }\n", " name=_einsum\n", " ] a b\n", " g:f32[4,3] = add c 1.0\n", " h:f32[4,3] = tanh g\n", " i:f32[4] = reduce_sum[axes=(1,)] h\n", " in (i,) }" ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "jax.make_jaxpr(vmap(model))(v_batch)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QO4yt0ahywiB", "outputId": "c82a8ea8-e38a-4e43-8ad8-9cc07792e5d7" }, "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{ lambda a:f32[3,5]; b:f32[4,5]. let\n", " c:f32[3,4] = dot_general[\n", " dimension_numbers=(((1,), (1,)), ((), ()))\n", " precision=None\n", " preferred_element_type=None\n", " ] a b\n", " d:f32[3,4] = add c 1.0\n", " e:f32[3,4] = tanh d\n", " f:f32[4] = reduce_sum[axes=(0,)] e\n", " in (f,) }" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "markdown", "source": [ "---\n", "The details differ slightly — for example, xla_call comes from the fact that einsum is jit compiled — but the essential steps in the computation match more-or-less exactly: dot_general(), then add(), then tanh(), then reduce_sum().\n", "\n", "
\n",
        "{ lambda a:f32[3,5]; b:f32[4,5]. let                      { lambda a:f32[3,5]; b:f32[4,5]. let\n",
        "    c:f32[4,3] = xla_call[                                    c:f32[3,4] = dot_general[\n",
        "      call_jaxpr={ lambda ; d:f32[3,5] e:f32[4,5]. let          dimension_numbers=(((1,), (1,)), ((), ()))\n",
        "          f:f32[4,3] = dot_general[                             precision=None\n",
        "            dimension_numbers=(((1,), (1,)), ((), ()))          preferred_element_type=None\n",
        "            precision=None                                    ] a b\n",
        "            preferred_element_type=None                       d:f32[3,4] = add c 1.0\n",
        "          ] e d                                               e:f32[3,4] = tanh d\n",
        "        in (f,) }                                             f:f32[4] = reduce_sum[axes=(0,)] e\n",
        "      name=_einsum                                          in (f,) }\n",
        "    ] a b\n",
        "    g:f32[4,3] = add c 1.0\n",
        "    h:f32[4,3] = tanh g\n",
        "    i:f32[4] = reduce_sum[axes=(1,)] h\n",
        "  in (i,) }\n",
        "
" ], "metadata": { "id": "hupLvslAz8o2" } }, { "cell_type": "markdown", "source": [ "---\n", "And this is what jax.vmap gives you: a way to automatically create efficient batched versions of your functions – that will lower to fast vectorized computations – without having to re-write your code by hand.\n", "\n", "You can read more about vmap and related transforms in the JAX docs: https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html" ], "metadata": { "id": "yVbYunFrddch" } } ] }