JAX's {class}`jax.Array` object is designed with distributed data and computation in mind.
This section will cover three modes of parallel computation:
- Automatic parallelism via {func}`jax.jit`, in which we let the compiler choose the optimal computation strategy
- Semi-automatic parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`
These examples will be run on Colab's free TPU runtime, which provides eight devices to work with:
```{code-cell}
:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456
import jax
jax.devices()
```
## Key concept: data sharding
Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.
Each concrete {class}`jax.Array` object has a `sharding` attribute and a `devices()` method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:
```{code-cell}
:outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
```
```{code-cell}
:outputId: 536f773a-7ef4-4526-c58b-ab4d486bf5a1
arr.sharding
```
For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array:
```{code-cell}
:outputId: 74a793e9-b13b-4d07-d8ec-7e25c547036d
jax.debug.visualize_array_sharding(arr)
```
To create an array with a non-trivial sharding, we can define a `sharding` specification for the array and pass this to {func}`jax.device_put`.
Here we'll define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes:
```{code-cell}
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
# Pardon the boilerplate; constructing a sharding will become easier soon!
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the eight devices in a 2x4 grid with named axes 'x' and 'y', then
# shard the array's first dimension along 'x' and its second along 'y'.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
print(sharding)
```
Passing this `sharding` to {func}`jax.device_put`, we obtain a sharded array:
```{code-cell}
:outputId: c8ceedba-05ca-4156-e6e4-1e98bb664a66
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.
## Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!
The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.
In the simplest of cases, those heuristics boil down to *computation follows data*.
For example, here's a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
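```{code-cell}
@jax.jit
def f_elementwise(x):
  # A simple element-wise function; any element-wise operation behaves the same way.
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
```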
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here we sum along the leading axis of `x`:
```{code-cell}
:outputId: 90c3b997-3653-4a7b-c8ff-12a270f11d02
@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```
The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second two on `1` and `7`, and so on.
## Semi-automated sharding with constraints
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function.
For example, suppose that within `f_contract` above, you'd prefer the output not to be partially replicated, but rather to be fully sharded across the eight devices:
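```{code-cell}
@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # Construct a 1D mesh over all eight devices and constrain the output to be
  # sharded along it. (The explicit mesh construction here is one option; any
  # NamedSharding over the full device set would work.)
  devices = mesh_utils.create_device_mesh((8,))
  mesh = jax.sharding.Mesh(devices, axis_names=('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```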
This gives you a function with the particular output sharding you'd like.
## Manual parallelism with `shard_map`
In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices.
By contrast, with `shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.
`shard_map` works by mapping a function across a particular *mesh* of devices:
```{code-cell}
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2
from jax.experimental.shard_map import shard_map

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
```
The function you write only "sees" a single batch of the data, which we can see by printing the device local shape:
```{code-cell}
:outputId: 99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```
Because your function only sees the device-local part of the data, aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a `sum` looks like:
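```{code-cell}
def f(x):
  # x here is the device-local shard; keepdims=True makes each shard return a
  # length-1 array, and the per-shard results are concatenated along 'x'.
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```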