Published: January 9, 2023

Let's talk about JAX's vmap! It's a transformation that can automatically create vectorized, batched versions of your functions... but what exactly it does is sometimes misunderstood. So let's dig in!


Suppose you've implemented a model that maps a vector input to a scalar output. As an example, here's a simple function similar to a single neuron in a neural net:

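The original snippet only survives as an image; a minimal sketch of a function with this shape might look as follows (the weight matrix, bias, and their values here are illustrative stand-ins, not the thread's exact code):

```python
import jax.numpy as jnp

# Illustrative stand-in parameters -- the thread's actual values differ
w = jnp.ones((3, 5)) * 0.1  # weight matrix: length-5 input -> 3 units
b = jnp.ones(3) * 0.5       # bias

def model(x):
    # maps a single length-5 vector to a scalar
    return jnp.sum(jnp.tanh(w @ x + b))
```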

This function accepts a single length-5 vector of inputs, and outputs a scalar:

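With a vector-to-scalar model like the one described (sketched here with made-up parameters), a single length-5 input produces a scalar:

```python
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def model(x):
    return jnp.sum(jnp.tanh(w @ x + b))

x = jnp.arange(5.0)
print(model(x).shape)  # () -- a scalar
```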

Now, suppose you want to apply this model across a 2D array, where each row of the array is an input. Passing this batched data directly leads to an error:

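Reproducing that failure with the illustrative model above: passing a 2D batch makes the shapes inside the matrix product incompatible, and JAX raises an error:

```python
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def model(x):
    return jnp.sum(jnp.tanh(w @ x + b))

xs = jnp.ones((4, 5))  # a batch: each of the 4 rows is one input
try:
    model(xs)          # shapes no longer line up inside the matmul
except Exception as err:
    print(type(err).__name__, err)
```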

This error arises because our function is not defined in a way that can handle batched input. So what do we do? The easiest approach might be to use a simple Python list comprehension:

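A list-comprehension version along these lines (again with stand-in parameters) calls the function once per row in Python:

```python
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def model(x):
    return jnp.sum(jnp.tanh(w @ x + b))

xs = jnp.ones((4, 5))
outputs = jnp.array([model(x) for x in xs])  # one Python iteration per row
print(outputs.shape)  # (4,)
```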

This works, of course, but if you're familiar with NumPy-style computing in Python you'll immediately recognize the problem: loops in Python are typically slow compared to the native vectorized operations offered by NumPy & JAX.

In the old days, you'd have to re-write your model to explicitly accept batched data. This sometimes takes some thought, for example here the simple matrix product becomes an Einstein summation:

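A manually batched version of the illustrative model might look like this, with the matrix product rewritten as an einsum over the shared feature axis (the index labels here are my own, not the thread's):

```python
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def batched_model(xs):
    # rows of xs are individual inputs; the matrix product becomes an
    # Einstein summation contracting the shared feature axis j
    return jnp.sum(jnp.tanh(jnp.einsum('kj,nj->nk', w, xs) + b), axis=-1)

print(batched_model(jnp.ones((4, 5))).shape)  # (4,)
```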

As models get more complex, this sort of manual batchification can be complicated and error-prone. This is where jax.vmap comes in: it can transform your function into an efficient and correct batched version automatically!

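Applied to the illustrative model, a single call to jax.vmap produces the batched version (by default it maps over the leading axis of the input):

```python
import jax
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def model(x):
    return jnp.sum(jnp.tanh(w @ x + b))

batched = jax.vmap(model)   # maps over the leading axis by default
print(batched(jnp.ones((4, 5))).shape)  # (4,)
```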

You might ask now which approach is more efficient: surely vmap must come at a cost? In most cases, however, vmap will produce virtually identical operations as the manual implementation, which we can see by printing the jaxpr (JAX's internal function representation) for each.

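The jaxprs can be inspected with jax.make_jaxpr; against the illustrative model sketched earlier, the comparison looks like this:

```python
import jax
import jax.numpy as jnp

w = jnp.ones((3, 5)) * 0.1  # stand-in weights
b = jnp.ones(3) * 0.5       # stand-in bias

def model(x):
    return jnp.sum(jnp.tanh(w @ x + b))

def batched_model(xs):
    # hand-batched version using einsum over the feature axis
    return jnp.sum(jnp.tanh(jnp.einsum('kj,nj->nk', w, xs) + b), axis=-1)

xs = jnp.ones((4, 5))
print(jax.make_jaxpr(batched_model)(xs))     # manual batching
print(jax.make_jaxpr(jax.vmap(model))(xs))   # vmap-generated batching
```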

The details differ slightly — for example, xla_call comes from the fact that einsum is jit-compiled — but the essential steps in the computation match more or less exactly: dot_general(), then add(), then tanh(), then reduce_sum().


And this is what jax.vmap gives you: a way to automatically create efficient batched versions of your functions, ones that lower to fast vectorized computations, without having to re-write your code by hand. You can read more about vmap in the JAX docs: https://jax.readthedocs.io/en/...

For runnable versions of all the code snippets in this thread, see jax_vmap.ipynb in this gist: https://gist.github.com/jakevd...
