Merge pull request #21168 from 8bitmp3:upgrade-sharded--doc

PiperOrigin-RevId: 632648408
This commit is contained in:
jax authors 2024-05-10 17:44:15 -07:00
commit 979d9ca3e5
2 changed files with 106 additions and 80 deletions

View File

@ -7,15 +7,17 @@
"(sharded-computation)=\n",
"# Introduction to sharded computation\n",
"\n",
"JAX's {class}`jax.Array` object is designed with distributed data and computation in mind.\n",
"This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n",
"\n",
"This section will cover three modes of parallel computation:\n",
"The tutorial covers three modes of parallel computation:\n",
"\n",
"- Automatic parallelism via {func}`jax.jit`, in which we let the compiler choose the optimal computation strategy\n",
"- Semi-automatic parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n",
"- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`\n",
"- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n",
"- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n",
"- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n",
"\n",
"These examples will be run on Colab's free TPU runtime, which provides eight devices to work with:"
"Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n",
"\n",
"If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)."
]
},
{
@ -52,11 +54,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Key concept: data sharding\n",
"## Key concept: Data sharding\n",
"\n",
"Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n",
"\n",
"Each concrete {class}`jax.Array` object has a `sharding` attribute and a `devices()` method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:"
"How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
"\n",
"In the simplest cases, arrays are sharded on a single device, as demonstrated below:"
]
},
{
@ -109,7 +113,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array:"
"For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in memory of a single device:"
]
},
{
@ -161,8 +165,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To create an array with a non-trivial sharding, we can define a `sharding` specification for the array and pass this to {func}`jax.device_put`.\n",
"Here we'll define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes:"
"To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.\n",
"\n",
"Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:"
]
},
{
@ -181,7 +186,7 @@
}
],
"source": [
"# Pardon the boilerplate; constructing a sharding will become easier soon!\n",
"# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
"from jax.sharding import Mesh\n",
"from jax.sharding import PartitionSpec\n",
"from jax.sharding import NamedSharding\n",
@ -198,7 +203,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing this `sharding` to {func}`jax.device_put`, we obtain a sharded array:"
"Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:"
]
},
{
@ -267,14 +272,14 @@
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"## 1. Automatic parallelism via `jit`\n",
"\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
"\n",
"## Automatic parallelism via `jit`\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!\n",
"The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.\n",
"In the simplest of cases, those heuristics boil down to *computation follows data*.\n",
"\n",
"For example, here's a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:"
"To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:"
]
},
{
@ -307,7 +312,8 @@
"metadata": {},
"source": [
"As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.\n",
"Here we sum along the leading axis of `x`:"
"\n",
"Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):"
]
},
{
@ -356,11 +362,9 @@
"source": [
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
"\n",
"## 2. Semi-automated sharding with constraints\n",
"\n",
"\n",
"## Semi-automated sharding with constraints\n",
"\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function.\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
"\n",
"For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:"
]
@ -416,14 +420,16 @@
"source": [
"This gives you a function with the particular output sharding you'd like.\n",
"\n",
"## 3. Manual parallelism with `shard_map`\n",
"\n",
"In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n",
"\n",
"## Manual parallelism with `shard_map`\n",
"`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n",
"\n",
"In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices.\n",
"By contrast, with `shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n",
"- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n",
"- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n",
"\n",
"`shard_map` works by mapping a function across a particular *mesh* of devices:"
"**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it."
]
},
{
@ -469,7 +475,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The function you write only \"sees\" a single batch of the data, which we can see by printing the device local shape:"
"The function you write only \"sees\" a single batch of the data, which you can check by printing the device local shape:"
]
},
{
@ -503,8 +509,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Because each of your functions only sees the device-local part of the data, it means that aggregation-like functions require some extra thought.\n",
"For example, here's what a `shard_map` of a `sum` looks like:"
"Because each of your functions only \"sees\" the device-local part of the data, it means that aggregation-like functions require some extra thought.\n",
"\n",
"For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:"
]
},
{
@ -536,8 +543,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Our function `f` operates separately on each shard, and the resulting summation reflects this.\n",
"If we want to sum across shards, we need to explicitly request it using collective operations like {func}`jax.lax.psum`:"
"Your function `f` operates separately on each shard, and the resulting summation reflects this.\n",
"\n",
"If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:"
]
},
{
@ -570,14 +578,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the output no longer has a sharded dimension, we set `out_specs=P()`.\n",
"\n",
"\n",
"Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).\n",
"\n",
"## Comparing the three approaches\n",
"\n",
"With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.\n",
"We'll define our canonical function like this:"
"\n",
"Start by defining your canonical function like this:"
]
},
{
@ -626,8 +633,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n",
"If we shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will autoatically happen in parallel:"
"You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n",
"\n",
"If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:"
]
},
{
@ -663,7 +671,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:"
"Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:"
]
},
{
@ -698,7 +706,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we can do the same thing with `shard_map`, using `psum` to indicate the cross-shard collective required for the matrix product:"
"Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:"
]
},
{
@ -736,8 +744,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This section has been a brief introduction of sharded and parallel computation;\n",
"for more discussion of `shard_map`, see {doc}`../notebooks/shard_map`."
"## Next steps\n",
"\n",
"This tutorial serves as a brief introduction of sharded and parallel computation in JAX.\n",
"\n",
"To learn about each SPMD method in-depth, check out these docs:\n",
"- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n",
"- {doc}`../notebooks/shard_map`"
]
}
],

View File

@ -14,15 +14,17 @@ kernelspec:
(sharded-computation)=
# Introduction to sharded computation
JAX's {class}`jax.Array` object is designed with distributed data and computation in mind.
This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.
This section will cover three modes of parallel computation:
The tutorial covers three modes of parallel computation:
- Automatic parallelism via {func}`jax.jit`, in which we let the compiler choose the optimal computation strategy
- Semi-automatic parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`
- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel").
- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- _Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives
These examples will be run on Colab's free TPU runtime, which provides eight devices to work with:
Using these approaches to SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.
If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with).
```{code-cell}
:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456
@ -31,11 +33,13 @@ import jax
jax.devices()
```
## Key concept: data sharding
## Key concept: Data sharding
Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.
Each concrete {class}`jax.Array` object has a `sharding` attribute and a `devices()` method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:
How does JAX understand how the data is laid out across devices? JAX's {class}`jax.Array` immutable array data structure represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {class}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
```{code-cell}
:outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a
@ -51,7 +55,7 @@ arr.devices()
arr.sharding
```
For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array:
For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in the memory of a single device:
```{code-cell}
:outputId: 74a793e9-b13b-4d07-d8ec-7e25c547036d
@ -59,13 +63,14 @@ For a more visual representation of the storage layout, the {mod}`jax.debug` mod
jax.debug.visualize_array_sharding(arr)
```
To create an array with a non-trivial sharding, we can define a `sharding` specification for the array and pass this to {func}`jax.device_put`.
Here we'll define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes:
To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.
Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:
```{code-cell}
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
# Pardon the boilerplate; constructing a sharding will become easier soon!
# Pardon the boilerplate; constructing a sharding will become easier in the future!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
@ -78,7 +83,7 @@ sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
```
Passing this `sharding` to {func}`jax.device_put`, we obtain a sharded array:
Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:
```{code-cell}
:outputId: c8ceedba-05ca-4156-e6e4-1e98bb664a66
@ -91,14 +96,14 @@ jax.debug.visualize_array_sharding(arr_sharded)
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
## 1. Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
## Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!
The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.
In the simplest of cases, those heuristics boil down to *computation follows data*.
For example, here's a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
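As a rough sketch, such a staged-out element-wise function can be as simple as the following (assuming the `arr_sharded` array created above; this is an illustrative sketch, not necessarily the notebook's exact cell):

```python
# Illustrative sketch: an element-wise function staged out with jax.jit,
# applied to the sharded array from the earlier cells.
import jax
import jax.numpy as jnp

@jax.jit
def f_elementwise(x):
    return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)  # `arr_sharded` comes from the earlier cell
print("shardings match:", result.sharding == arr_sharded.sharding)
```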
```{code-cell}
:outputId: de46f86a-6907-49c8-f36c-ed835e78bc3d
@ -113,7 +118,8 @@ print("shardings match:", result.sharding == arr_sharded.sharding)
```
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here we sum along the leading axis of `x`:
Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):
```{code-cell}
:outputId: 90c3b997-3653-4a7b-c8ff-12a270f11d02
@ -129,11 +135,9 @@ print(result)
The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.
## 2. Semi-automated sharding with constraints
## Semi-automated sharding with constraints
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function.
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put`) together with {func}`jax.jit` to constrain how the compiler distributes intermediate values and outputs within the computation.
For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
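A minimal sketch of such a constrained function, assuming the 2×4 `mesh`, `P`, and `arr_sharded` objects from the earlier cells (illustrative only; the notebook's actual cell may differ):

```python
# Illustrative sketch: constrain the output sharding inside a jitted function
# instead of relying on the compiler's default layout choice.
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def f_contract_sharded(x):
    out = x.sum(axis=0)
    # Ask for the 1D result to be sharded across both mesh axes,
    # i.e. fully sharded over the eight devices.
    out_sharding = NamedSharding(mesh, P(('x', 'y')))
    return jax.lax.with_sharding_constraint(out, out_sharding)

result = f_contract_sharded(arr_sharded)
jax.debug.visualize_array_sharding(result)
```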
@ -156,14 +160,16 @@ print(result)
This gives you a function with the particular output sharding you'd like.
## 3. Manual parallelism with `shard_map`
In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.
## Manual parallelism with `shard_map`
`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:
In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices.
By contrast, with `shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.
- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the `axis_names` parameter giving logical names to the physical device axes.
- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.
`shard_map` works by mapping a function across a particular *mesh* of devices:
**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it.
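For example, a self-contained sketch of this pattern might look as follows (illustrative names; assumes an 8-device runtime, so the 32-element input splits into shards of 4):

```python
# Illustrative sketch: map a per-shard function over a 1D mesh with shard_map.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('x',))  # 1D mesh over all devices

f_elementwise_sharded = shard_map(
    lambda x: 2 * jnp.sin(x) + 1,  # operates on one shard at a time
    mesh=mesh,
    in_specs=P('x'),    # split the input along its (only) axis
    out_specs=P('x'))   # reassemble the per-shard outputs along the same axis

arr = jnp.arange(32)
f_elementwise_sharded(arr)
```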
```{code-cell}
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2
@ -182,7 +188,7 @@ arr = jnp.arange(32)
f_elementwise_sharded(arr)
```
The function you write only "sees" a single batch of the data, which we can see by printing the device local shape:
The function you write only "sees" a single batch of the data, which you can check by printing the device local shape:
```{code-cell}
:outputId: 99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da
@ -197,8 +203,9 @@ def f(x):
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```
Because each of your functions only sees the device-local part of the data, it means that aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a `sum` looks like:
Because each of your functions only "sees" the device-local part of the data, it means that aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:
```{code-cell}
:outputId: 1e9a45f5-5418-4246-c75b-f9bc6dcbbe72
@ -209,8 +216,9 @@ def f(x):
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```
Our function `f` operates separately on each shard, and the resulting summation reflects this.
If we want to sum across shards, we need to explicitly request it using collective operations like {func}`jax.lax.psum`:
Your function `f` operates separately on each shard, and the resulting summation reflects this.
If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:
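As a rough sketch, such a collective-based reduction looks like this (assuming the `mesh`, `P`, `shard_map`, and `x` from the previous cells; the notebook's actual cell may differ slightly):

```python
# Illustrative sketch: a per-shard partial sum combined across shards with a
# psum collective over the mesh axis 'x'.
import jax

def f(x):
    sum_in_shard = x.sum()                  # local sum of this device's shard
    return jax.lax.psum(sum_in_shard, 'x')  # add the partial sums across shards

# out_specs=P() because the result has no sharded dimensions left.
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```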
```{code-cell}
:outputId: 4fd29e80-4fee-42b7-ff80-29f9887ab38d
@ -222,14 +230,13 @@ def f(x):
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```
Because the output no longer has a sharded dimension, we set `out_specs=P()`.
Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).
## Comparing the three approaches
With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.
We'll define our canonical function like this:
Start by defining your canonical function like this:
```{code-cell}
:id: 1TdhfTsoiqS1
@ -252,8 +259,9 @@ bias = rng.normal(size=(4,))
layer(x, weights, bias)
```
We can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.
If we shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will autoatically happen in parallel:
You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.
If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:
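A minimal sketch of that setup, assuming the `mesh` and the `x`, `weights`, and `bias` arrays defined above (illustrative; the notebook's actual cell may differ):

```python
# Illustrative sketch: place the inputs with a matching sharding, then call the
# jitted `layer` unchanged; the compiler parallelizes the matrix multiplication.
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

sharding = NamedSharding(mesh, P('x'))  # shard the leading axis along 'x'

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
```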
```{code-cell}
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5
@ -268,7 +276,7 @@ weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias)
```
Alternatively, we can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:
Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:
```{code-cell}
:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4
@ -282,7 +290,7 @@ def layer_auto(x, weights, bias):
layer_auto(x, weights, bias) # pass in unsharded inputs
```
Finally, we can do the same thing with `shard_map`, using `psum` to indicate the cross-shard collective required for the matrix product:
Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:
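A rough sketch of what such a `shard_map`-based layer can look like, assuming the 1D `mesh`, `P`, `shard_map`, and the `x`, `weights`, `bias` arrays from above (illustrative; the notebook's actual cell may differ):

```python
# Illustrative sketch: each device holds a slice of x and the matching rows of
# weights; psum combines the partial matrix products before bias and sigmoid.
from functools import partial

import jax

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
    return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
```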
```{code-cell}
:outputId: 568d1c85-39a7-4dba-f09a-0e4f7c2ea918
@ -299,5 +307,10 @@ def layer_sharded(x, weights, bias):
layer_sharded(x, weights, bias)
```
This section has been a brief introduction of sharded and parallel computation;
for more discussion of `shard_map`, see {doc}`../notebooks/shard_map`.
## Next steps
This tutorial serves as a brief introduction to sharded and parallel computation in JAX.
To learn about each SPMD method in depth, check out these docs:
- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`
- {doc}`../notebooks/shard_map`