DOC: add introduction to sharded computation
This commit is contained in:
parent 06cd05d1d6
commit 8e34da70f8
@@ -126,6 +126,7 @@ exclude_patterns = [
     'jep/9407-type-promotion.md',
     'jax-101/*.md',
     'autodidax.md',
+    'tutorials/sharded-computation.md',
 ]
 
 # The name of the Pygments (syntax highlighting) style to use.
@@ -213,6 +214,7 @@ nb_execution_excludepatterns = [
     # Requires accelerators
     'pallas/quickstart.*',
     'pallas/tpu/pipelining.*',
+    'tutorials/sharded-computation.*'
 ]
 
 # -- Options for HTMLHelp output ---------------------------------------------
@@ -26,7 +26,7 @@ JAX 101
 debugging
 random-numbers
 working-with-pytrees
-single-host-sharding
+sharded-computation
 stateful-computations
 simple-neural-network
 
@@ -58,7 +58,7 @@ x.sharding
 
 Here the array is on a single device, but in general a JAX array can be
 sharded across multiple devices, or even multiple hosts.
-To read more about sharded arrays and parallel computation, refer to {ref}`single-host-sharding`
+To read more about sharded arrays and parallel computation, refer to {ref}`sharded-computation`
 
 (key-concepts-transformations)=
 ## Transformations
764 docs/tutorials/sharded-computation.ipynb Normal file
@@ -0,0 +1,764 @@
(Notebook contents: the jupytext-paired source of docs/tutorials/sharded-computation.md shown below, with cell outputs from Colab's eight-device TPU runtime; the raw JSON and ANSI-colored output dumps are not repeated here.)
303 docs/tutorials/sharded-computation.md Normal file
@@ -0,0 +1,303 @@
---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.1
kernelspec:
  display_name: Python 3
  name: python3
---

(sharded-computation)=
# Introduction to sharded computation

JAX's {class}`jax.Array` object is designed with distributed data and computation in mind.

This section will cover three modes of parallel computation:

- Automatic parallelism via {func}`jax.jit`, in which we let the compiler choose the optimal computation strategy
- Semi-automatic parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`

These examples will be run on Colab's free TPU runtime, which provides eight devices to work with:

```{code-cell}
:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456

import jax
jax.devices()
```
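If you're following along without a TPU runtime, one common workaround (an added sketch, not part of the original notebook) is to ask XLA for several virtual CPU devices via the `XLA_FLAGS` environment variable. Note that the flag must be set before JAX is first imported:

```python
# Sketch: simulate eight devices on CPU. This must run before `import jax`,
# so place it at the very top of your script or notebook.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # expect eight CpuDevice entries
```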
## Key concept: data sharding

Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.

Each concrete {class}`jax.Array` object has a `sharding` attribute and a `devices()` method that can give you insight into how the underlying data are stored. In the simplest cases, arrays are sharded on a single device:

```{code-cell}
:outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
```

```{code-cell}
:outputId: 536f773a-7ef4-4526-c58b-ab4d486bf5a1

arr.sharding
```

For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array:

```{code-cell}
:outputId: 74a793e9-b13b-4d07-d8ec-7e25c547036d

jax.debug.visualize_array_sharding(arr)
```

To create an array with a non-trivial sharding, we can define a `sharding` specification for the array and pass this to {func}`jax.device_put`.
Here we'll define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes:

```{code-cell}
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5

# Pardon the boilerplate; constructing a sharding will become easier soon!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, P('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
```

Passing this `sharding` to {func}`jax.device_put`, we obtain a sharded array:

```{code-cell}
:outputId: c8ceedba-05ca-4156-e6e4-1e98bb664a66

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.
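If you want to check exactly which slice of the global array each device holds, one way to do so (an added sketch, not part of the original notebook) is to iterate over the array's `addressable_shards`:

```python
# Sketch: each entry of addressable_shards records the device it lives on,
# the index (slice) of the global array it holds, and the local data.
for shard in arr_sharded.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```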
## Automatic parallelism via `jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a JIT-compiled function!
The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.
In the simplest of cases, those heuristics boil down to *computation follows data*.

For example, here's a simple element-wise function: the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:

```{code-cell}
:outputId: de46f86a-6907-49c8-f36c-ed835e78bc3d

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
```

As computations get more complex, the compiler makes decisions about how best to propagate the sharding of the data.
Here we sum along the leading axis of `x`:

```{code-cell}
:outputId: 90c3b997-3653-4a7b-c8ff-12a270f11d02

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```

The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second two on `1` and `7`, and so on.
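To confirm this replication programmatically (an added sketch, not in the original notebook), you can inspect the result's sharding and its per-device shards:

```python
# Sketch: the summed result is sharded over four groups of two devices,
# so it is not fully replicated, and each replica group shares one slice.
print(result.sharding.is_fully_replicated)  # expect False
for shard in result.addressable_shards:
    print(shard.device, shard.index, shard.replica_id)
```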
## Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function.

For example, suppose that within `f_contract` above, you'd prefer the output not to be partially replicated, but rather to be fully sharded across the eight devices:

```{code-cell}
:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # mesh = jax.create_mesh((8,), 'x')
  devices = mesh_utils.create_device_mesh(8)
  mesh = jax.sharding.Mesh(devices, P('x'))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```

This gives you a function with the particular output sharding you'd like.
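Depending on your JAX version, you may be able to request the same output sharding without editing the function body by passing `out_shardings` to {func}`jax.jit` directly. The following is a sketch of that alternative (the names here are ours, and it assumes a JAX release that supports the `out_shardings` argument):

```python
# Sketch: ask jit itself for a fully sharded output, rather than constraining
# the intermediate value inside the function.
devices_1d = mesh_utils.create_device_mesh((8,))
mesh_1d = jax.sharding.Mesh(devices_1d, ('x',))
row_sharding = jax.sharding.NamedSharding(mesh_1d, P('x'))

f_contract_3 = jax.jit(lambda x: x.sum(axis=0), out_shardings=row_sharding)
jax.debug.visualize_array_sharding(f_contract_3(arr_sharded))
```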
## Manual parallelism with `shard_map`

In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices.
By contrast, with `shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.

`shard_map` works by mapping a function across a particular *mesh* of devices:

```{code-cell}
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2

from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
```

The function you write only "sees" a single batch of the data, which we can see by printing the device-local shape:

```{code-cell}
:outputId: 99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```

Because each call to your function only sees the device-local part of the data, aggregation-like functions require some extra thought.
For example, here's what a `shard_map` of a `sum` looks like:

```{code-cell}
:outputId: 1e9a45f5-5418-4246-c75b-f9bc6dcbbe72

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```

Our function `f` operates separately on each shard, and the resulting summation reflects this.
If we want to sum across shards, we need to explicitly request it using collective operations like {func}`jax.lax.psum`:

```{code-cell}
:outputId: 4fd29e80-4fee-42b7-ff80-29f9887ab38d

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```

Because the output no longer has a sharded dimension, we set `out_specs=P()`.
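`psum` is just one of the available collectives. As a further illustration (an added sketch, not from the original notebook, with the helper name `f_mean` ours), a cross-shard average can be written the same way with {func}`jax.lax.pmean`:

```python
# Sketch: average across all shards by combining the per-shard means.
# The shards are equal-sized here, so this matches the global mean (15.5).
def f_mean(x):
  return jax.lax.pmean(x.mean(), 'x')

shard_map(f_mean, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```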
## Comparing the three approaches

With these concepts fresh in our minds, let's compare the three approaches for a simple neural network layer.
We'll define our canonical function like this:

```{code-cell}
:id: 1TdhfTsoiqS1

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
```

```{code-cell}
:outputId: f3007fe4-f6f3-454e-e7c5-3638de484c0a

import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
```

We can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.
If we shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:

```{code-cell}
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
```

Alternatively, we can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:

```{code-cell}
:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs
```

Finally, we can do the same thing with `shard_map`, using `psum` to indicate the cross-shard collective required for the matrix product:

```{code-cell}
:outputId: 568d1c85-39a7-4dba-f09a-0e4f7c2ea918

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
```
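As a quick sanity check (an added sketch, not part of the original notebook), we can confirm that the three strategies agree up to floating-point tolerance:

```python
# Sketch: the three approaches compute the same layer, so their outputs
# should match to within floating-point error.
out_jit = layer(x_sharded, weights_sharded, bias)
out_constrained = layer_auto(x, weights, bias)
out_manual = layer_sharded(x, weights, bias)

assert np.allclose(out_jit, out_constrained, atol=1e-5)
assert np.allclose(out_jit, out_manual, atol=1e-5)
```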
This section has been a brief introduction to sharded and parallel computation;
for more discussion of `shard_map`, see {doc}`../notebooks/shard_map`.
@ -1,6 +0,0 @@
|
||||
(single-host-sharding)=
|
||||
# Sharded data on a single host
|
||||
|
||||
```{note}
|
||||
This is a placeholder for a section in the new {ref}`jax-tutorials`.
|
||||
```
|
@@ -134,7 +134,7 @@ x.sharding
 
 In this case the sharding is on a single device, but in general a JAX array can be
 sharded across multiple devices, or even multiple hosts.
-To read more about sharded arrays and parallel computation, refer to {ref}`single-host-sharding`
+To read more about sharded arrays and parallel computation, refer to {ref}`sharded-computation`
 
 (thinking-in-jax-pytrees)=
 ## Pytrees