"This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
"\n",
"To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine or a Cloud TPU.\n",
"The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
"But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
"Notice that applying `vmap(f)` to these arguments leads to a `dot_general` to express the batch matrix multiplication in a single primitive, while applying `pmap(f)` instead leads to a primitive that calls replicas of the original `f` in parallel.\n",
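"\n",
"As a minimal sketch for inspecting the batched form yourself, the snippet below prints the jaxpr of a vectorized matrix-vector product. The names `f`, `As`, `xs`, and `batched_jaxpr` are illustrative assumptions and are not the definitions used elsewhere in this notebook.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f(A, x):\n",
"    # a single matrix-vector product\n",
"    return jnp.dot(A, x)\n",
"\n",
"As = jnp.ones((4, 3, 5))  # a batch of four 3x5 matrices\n",
"xs = jnp.ones((4, 5))     # a batch of four length-5 vectors\n",
"\n",
"# trace the vectorized function and print its primitives\n",
"batched_jaxpr = jax.make_jaxpr(jax.vmap(f))(As, xs)\n",
"print(batched_jaxpr)\n",
"```\n",
"\n",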
"An important constraint when using `pmap` is that ",
"the mapped axis size must be less than or equal to the number of XLA devices available (and for nested `pmap` functions, the product of the mapped axis sizes must be less than or equal to the number of XLA devices).\n",
"But while the output here acts just like a NumPy ndarray, if you look closely it has a different type:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "59hnyVOtfavX",
"colab_type": "code",
"colab": {}
},
"source": [
"y"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4brdSdeyf2MP",
"colab_type": "text"
},
"source": [
"A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
"\n",
"When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
"Thinking about device memory is important to maximize performance by avoiding data transfers, but you can always fall back to treating array-like values as (read-only) NumPy ndarrays and your code will still work.\n",
"\n",
"Here's another example of a pure map which makes better use of our multiple-accelerator resources. We can generate several large random matrices in parallel, then perform parallel batch matrix multiplication without any cross-device movement of the large matrix data:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rWl68coLJSi7",
"colab_type": "code",
"colab": {}
},
"source": [
"from jax import random\n",
"\n",
"# create 8 random keys\n",
"keys = random.split(random.PRNGKey(0), 8)\n",
"# create a 5000 x 6000 matrix on each device by mapping over keys\n",
"In this example, the large matrices never had to be moved between devices or back to the host; only one scalar per device was pulled back to the host."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MdRscR5MONuN",
"colab_type": "text"
},
"source": [
"### Collective communication operations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bFtajUwp5WYx",
"colab_type": "text"
},
"source": [
"In addition to expressing pure maps, where no communication happens between the replicated functions, with `pmap` you can also use special collective communication operations.\n",
"\n",
"One canonical example of a collective, implemented on both GPU and TPU, is an all-reduce sum like `lax.psum`:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "d5s8rJVUORQ3",
"colab_type": "code",
"colab": {}
},
"source": [
"from jax import lax\n",
"\n",
"normalize = lambda x: x / lax.psum(x, axis_name='i')\n",
"To use a collective operation like `lax.psum`, you need to supply an `axis_name` argument to `pmap`. The `axis_name` argument associates a name to the mapped axis so that collective operations can refer to it.\n",
"\n",
"Another way to write this same code is to use `pmap` as a decorator:"
"When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n",
"\n",
"Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n",
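"\n",
"As a minimal sketch of how an axis name ties a collective to its `pmap`, the snippet below normalizes one value per device with `lax.psum`. It sizes the mapped axis with `jax.local_device_count()` so that it respects the axis-size constraint and should run even on a single device. The names `n_devices`, `x`, and `normalized` are illustrative assumptions.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"# use one mapped element per locally available device\n",
"n_devices = jax.local_device_count()\n",
"x = jnp.arange(1., n_devices + 1.)\n",
"\n",
"# psum adds up v across the mapped axis named 'i'\n",
"normalized = jax.pmap(lambda v: v / lax.psum(v, axis_name='i'), axis_name='i')(x)\n",
"print(normalized.sum())\n",
"```\n",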
"\n",
"Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:"
"When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass."