Initial commit of third JAX-101 notebook

This commit is contained in:
Jake VanderPlas 2021-03-02 15:00:49 -08:00
parent e4b8fec287
commit 4976aee01b
3 changed files with 538 additions and 0 deletions

docs/jax-101/03-vectorization.ipynb

@@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "zMIrmiaZxiJC"
},
"source": [
"# Automatic vectorization in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/03-vectorization.ipynb)\n",
"\n",
"*Authors: TODO*\n",
"\n",
"In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kw-_imBrx4nN"
},
"source": [
"## Manual Vectorization\n",
"\n",
"Consider the following simple code that computes the convolution of two one-dimensional vectors:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5Obro91lwE_s",
"outputId": "061983c6-2faa-4a54-83a5-d2a823f61087"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"data": {
"text/plain": [
"DeviceArray([11., 20., 29.], dtype=float32)"
]
},
"execution_count": 1,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x = jnp.arange(5)\n",
"w = jnp.array([2., 3., 4.])\n",
"\n",
"def convolve(x, w):\n",
" output = []\n",
" for i in range(1, len(x)-1):\n",
" output.append(jnp.dot(x[i-1:i+2], w))\n",
" return jnp.array(output)\n",
"\n",
"convolve(x, w)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z_nPhEhLRysk"
},
"source": [
"Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "rHQJnnrVUbxE"
},
"outputs": [],
"source": [
"xs = jnp.stack([x, x])\n",
"ws = jnp.stack([w, w])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ghaJQW1aUfPi"
},
"source": [
"The most naive option would be to simply loop over the batch in Python:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "yM-IycdlzGyJ",
"outputId": "07ed6ffc-0265-45ef-d585-4b5fa7d221f1"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def manually_batched_convolve(xs, ws):\n",
" output = []\n",
" for i in range(xs.shape[0]):\n",
" output.append(convolve(xs[i], ws[i]))\n",
" return jnp.stack(output)\n",
"\n",
"manually_batched_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VTh0l_1SUlh4"
},
"source": [
"This produces the correct result, however it is not very efficient.\n",
"\n",
"In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.\n",
"\n",
"For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "I4Wd9nrcTRRL",
"outputId": "0b037b43-7b41-4625-f9e0-a6e0dbc4c65a"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 5,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def manually_vectorized_convolve(xs, ws):\n",
" output = []\n",
" for i in range(1, xs.shape[-1] -1):\n",
" output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))\n",
" return jnp.stack(output, axis=1)\n",
"\n",
"manually_vectorized_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DW-RJ2Zs2QVu"
},
"source": [
"Such re-implementation is messy and error-prone; fortunately JAX provides another way."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2oVLanQmUAo_"
},
"source": [
"## Automatic Vectorization\n",
"\n",
"In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Brl-BoTqSQDw",
"outputId": "af608dbb-27f2-4fbc-f225-79f3101b13ff"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"auto_batch_convolve = jax.vmap(convolve)\n",
"\n",
"auto_batch_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7aVAy7332lFj"
},
"source": [
"It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input.\n",
"\n",
"If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "_VEEm1CGT2n0",
"outputId": "751e0fbf-bdfb-41df-9436-4da5de23123f"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 11.],\n",
" [20., 20.],\n",
" [29., 29.]], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)\n",
"\n",
"xst = jnp.transpose(xs)\n",
"wst = jnp.transpose(ws)\n",
"\n",
"auto_batch_convolve_v2(xst, wst)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-gNiLuxzSX32"
},
"source": [
"`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "2s2YDsamSxki",
"outputId": "5c70879b-5cce-4549-e38a-f45dbe663ab2"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])\n",
"\n",
"batch_convolve_v3(xs, w)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bsxT4hA6RTCG"
},
"source": [
"## Combining transformations\n",
"\n",
"As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "gsC-Myg0RVdj",
"outputId": "cbdd384e-6633-4cea-b1a0-a01ad934a768"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jitted_batch_convolve = jax.jit(auto_batch_convolve)\n",
"\n",
"jitted_batch_convolve(xs, ws)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Vectorization in JAX",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

docs/jax-101/03-vectorization.md

@@ -0,0 +1,161 @@
---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.10.0
kernelspec:
  display_name: Python 3
  name: python3
---

+++ {"id": "zMIrmiaZxiJC"}
# Automatic vectorization in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/03-vectorization.ipynb)
*Authors: TODO*
In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`.
+++ {"id": "Kw-_imBrx4nN"}
## Manual Vectorization
Consider the following simple code that computes the convolution of two one-dimensional vectors:
```{code-cell} ipython3
:id: 5Obro91lwE_s
:outputId: 061983c6-2faa-4a54-83a5-d2a823f61087
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
```
+++ {"id": "z_nPhEhLRysk"}
Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`.
```{code-cell} ipython3
:id: rHQJnnrVUbxE
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
```
+++ {"id": "ghaJQW1aUfPi"}
The most naive option would be to simply loop over the batch in Python:
```{code-cell} ipython3
:id: yM-IycdlzGyJ
:outputId: 07ed6ffc-0265-45ef-d585-4b5fa7d221f1
def manually_batched_convolve(xs, ws):
output = []
for i in range(xs.shape[0]):
output.append(convolve(xs[i], ws[i]))
return jnp.stack(output)
manually_batched_convolve(xs, ws)
```
+++ {"id": "VTh0l_1SUlh4"}
This produces the correct result, however it is not very efficient.
In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.
For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows:
```{code-cell} ipython3
:id: I4Wd9nrcTRRL
:outputId: 0b037b43-7b41-4625-f9e0-a6e0dbc4c65a
def manually_vectorized_convolve(xs, ws):
output = []
for i in range(1, xs.shape[-1] -1):
output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws)
```
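
+++

We can quickly confirm that the manually batched and the manually vectorized implementations agree on the arrays defined above:

```{code-cell} ipython3
# Sanity check: both batched implementations produce identical results.
assert jnp.allclose(manually_batched_convolve(xs, ws),
                    manually_vectorized_convolve(xs, ws))
```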
+++ {"id": "DW-RJ2Zs2QVu"}
Such re-implementation is messy and error-prone; fortunately JAX provides another way.
+++ {"id": "2oVLanQmUAo_"}
## Automatic Vectorization
In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:
```{code-cell} ipython3
:id: Brl-BoTqSQDw
:outputId: af608dbb-27f2-4fbc-f225-79f3101b13ff
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
```
+++ {"id": "7aVAy7332lFj"}
It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input.
If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise.
```{code-cell} ipython3
:id: _VEEm1CGT2n0
:outputId: 751e0fbf-bdfb-41df-9436-4da5de23123f
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst)
```
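
+++

When different arguments are batched along different axes, `in_axes` can instead be given as a list with one entry per argument. As a small sketch using the arrays above (`mixed_batch_convolve` is just an illustrative name), we can batch `xs` along axis 0 while taking the transposed weights `wst` batched along axis 1:

```{code-cell} ipython3
# in_axes as a list: xs is batched along axis 0, wst along axis 1;
# out_axes=0 places the batch dimension first in the output.
mixed_batch_convolve = jax.vmap(convolve, in_axes=[0, 1], out_axes=0)

mixed_batch_convolve(xs, wst)
```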
+++ {"id": "-gNiLuxzSX32"}
`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`:
```{code-cell} ipython3
:id: 2s2YDsamSxki
:outputId: 5c70879b-5cce-4549-e38a-f45dbe663ab2
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
```
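
+++

The reverse case works the same way; as a sketch, a single vector `x` can be convolved with the batch of weights `ws` by setting the first entry of `in_axes` to `None` instead:

```{code-cell} ipython3
# Only ws is batched here; the single vector x is reused for every weight set.
batch_weights_convolve = jax.vmap(convolve, in_axes=[None, 0])

batch_weights_convolve(x, ws)
```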
+++ {"id": "bsxT4hA6RTCG"}
## Combining transformations
As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly:
```{code-cell} ipython3
:id: gsC-Myg0RVdj
:outputId: cbdd384e-6633-4cea-b1a0-a01ad934a768
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
```
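
+++

The composition also works in the other order; for instance, a minimal sketch that vmaps an already-jitted function:

```{code-cell} ipython3
# vmap applied on top of a jitted function behaves the same way.
vmapped_jitted_convolve = jax.vmap(jax.jit(convolve))

vmapped_jitted_convolve(xs, ws)
```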

docs/jax-101/index.rst

@@ -10,6 +10,7 @@ It is a work in progress; more content will be added soon!
01-jax-basics
02-jitting
03-vectorization
.. _Deepmind: http://deepmind.com