diff --git a/CHANGELOG.md b/CHANGELOG.md index d1a115749..78e893dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,7 @@ Remember to align the itemized text with the first line of an item within a list breaking change to the `pjit` API. The [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Parallelism with - JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html) + [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. diff --git a/docs/conf.py b/docs/conf.py index 0cb75b3be..403b671be 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -210,7 +210,7 @@ nb_execution_excludepatterns = [ # TODO(jakevdp): enable execution on the following if possible: 'jax-101/*', 'notebooks/xmap_tutorial.*', - 'notebooks/Parallelism_with_JAX.*', + 'notebooks/Distributed_arrays_and_automatic_parallelization.*', ] # -- Options for HTMLHelp output --------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 12a25c4c4..13f36aa2d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more. .. note:: JAX 0.4.0 introduces new parallelism APIs, including breaking changes to :func:`jax.experimental.pjit` and a new unified ``jax.Array`` type. - Please see `Parallelism with JAX <https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html>`_ tutorial and the :ref:`jax-array-migration` + Please see `Distributed arrays and automatic parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ tutorial and the :ref:`jax-array-migration` guide for more information. .. toctree:: @@ -52,7 +52,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more. notebooks/autodiff_cookbook multi_process - notebooks/Parallelism_with_JAX + notebooks/Distributed_arrays_and_automatic_parallelization notebooks/vmapped_log_probs notebooks/neural_network_with_tfds_data notebooks/Custom_derivative_rules_for_Python_code diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 682fe4f03..1fa91b60d 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -18,7 +18,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Parallelism with JAX](https://jax.readthedocs.io/en/latest/notebooks/Parallelism_with_JAX.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array?
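The changelog entry above notes that `PartitionSpec` and `Mesh` have moved out of experimental into `jax.sharding`. As a rough illustration of what that change means for user code, here is a minimal sketch of the new spellings alongside the older experimental ones that still appear in the notebook diff below; it assumes JAX v0.4.0+ and the eight devices the notebook requires:

```python
# Minimal sketch of the jax.sharding endpoints named in the changelog entry above.
# Assumes JAX v0.4.0+ and 8 available devices (as the renamed notebook requires).
from jax.sharding import Mesh, PartitionSpec   # new public locations
from jax.experimental import mesh_utils

# Older experimental spellings, still used inside the notebook below:
#   from jax.experimental import PartitionSpec
#   from jax.experimental import maps; maps.Mesh(...)

devices = mesh_utils.create_device_mesh((4, 2))   # topology-aware ndarray of Devices
mesh = Mesh(devices, axis_names=('a', 'b'))       # name the two mesh axes
spec = PartitionSpec('a', None)                   # shard dim 0 over 'a', replicate dim 1
```

A `NamedSharding`, as used in the notebook below, pairs such a `Mesh` with a `PartitionSpec` to describe a distributed layout.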
diff --git a/docs/notebooks/Parallelism_with_JAX.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb similarity index 84% rename from docs/notebooks/Parallelism_with_JAX.ipynb rename to docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 77601b149..9db01bdc7 100644 --- a/docs/notebooks/Parallelism_with_JAX.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -6,7 +6,7 @@ "id": "PxHrg4Cjuapm" }, "source": [ - "# Parallelism with JAX" + "# Distributed arrays and automatic parallelization" ] }, { @@ -17,17 +17,15 @@ "source": [ "**This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.0 and newer.**\n", "\n", - "See {ref}`jax-array-migration` guide for migrating existing pre-v0.4.0 codebases to `jax.Array`.\n", + "See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.0 codebases to `jax.Array`.\n", "\n", - "**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time.**" + "**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Cloud TPU.**" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "41dde63b" - }, + "execution_count": 1, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -49,7 +47,17 @@ "id": "eyHMwyEfQJcz" }, "source": [ - "⚠️ WARNING: The notebook requires multiple devices in order to run correctly." + "⚠️ WARNING: notebook requires 8 devices to run." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "if len(jax.local_devices()) < 8:\n", + " raise Exception(\"Notebook requires 8 devices to run\")" ] }, { @@ -58,116 +66,41 @@ "id": "3f37ca93" }, "source": [ - "## `Sharding` describes how array values are laid out in memory across devices" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W6HsXauGxL6w" - }, - "source": [ - "### Sharding basics, and the `PositionalSharding` subclass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NWDyp_EjVHkg" - }, - "source": [ - "To parallelize computation across multiple devices, we first have to describe how the data that computation acts on can be distributed across multiple devices. That means describing distributed layouts conveniently.\n", + "## Intro and a quick example\n", "\n", - "In JAX, `Sharding` objects describe distributed memory layouts. They can be used with `jax.device_put` to produce a value with distributed layout.\n", + "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n", + "datatype for representing arrays, even with physical storage spanning multiple\n", + "devices. 
You'll also learn about how using `jax.Array`s together with `jax.jit`\n", + "can provide automatic compiler-based parallelization.\n", "\n", - "For example, here's a value with a single-device `Sharding`:" + "Before we think step by step, here's a quick example.\n", + "First, we'll create a `jax.Array` sharded across multiple devices:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s5jXIod7VcWW" - }, - "outputs": [], - "source": [ - "import jax\n", - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HhCjhK0zXIqX" - }, - "source": [ - "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", - "\n", - "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8fc925d2", - "outputId": "40f044de-0a39-46cb-9e61-2277648a9c0e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "jax.debug.visualize_array_sharding(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xKC-WWgc8QGo" - }, - "source": [ - "A quick example of what `jax.Array` can do before we dive into more details:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kiZ59Mho5lzk" - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "from jax.experimental import mesh_utils\n", - "from jax.sharding import PositionalSharding\n", - "\n", - "# Let's create a Sharding object which we will use to distribute\n", - "# a value across devices.\n", + "from jax.sharding import PositionalSharding" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a Sharding object to distribute a value across devices:\n", "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "t_Mw8cxU7YSK", - "outputId": "b44f6bea-c195-4b9a-d872-815112e8ef2e" - }, + "execution_count": 5, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -186,6 +119,9 @@ } ], "source": [ + "# Create an array of random values:\n", + "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", + "# and use jax.device_put to distribute it across devices:\n", "y = jax.device_put(x, sharding.reshape(4, 2))\n", "jax.debug.visualize_array_sharding(y)" ] @@ -196,16 +132,14 @@ "id": "jZ0ZY9Um9Jg4" }, "source": [ - "We distributed `x` which was on a single device before to 8 devices." 
+ "Next, we'll apply a computation to it and visualize how the result values are\n", + "stored across multiple devices too:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WUzDakoG7a2l", - "outputId": "b575be88-cbb1-4081-fa98-1ba5c6421921" - }, + "execution_count": 6, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -234,23 +168,21 @@ "id": "5qccVQoE9tEi" }, "source": [ - "After doing an `jnp.sin` operation on a distributed `jax.Array`, the sharding on the `output` was preserved. Also the operation itself happened on multiple devices. To test that, we can do a timing experiment:" + "The evaluation of the `jnp.sin` application was automatically parallelized\n", + "across the devices on which the input values (and output values) are stored:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RdfmbYxr-2I_", - "outputId": "50e659d4-9ca0-499e-bd49-66fe59384675" - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The slowest run took 14.70 times longer than the fastest. This could mean that an intermediate result is being cached \n", - "5 loops, best of 5: 15.2 ms per loop\n" + "The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached \n", + "5 loops, best of 5: 9.69 ms per loop\n" ] } ], @@ -261,17 +193,14 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cNLJQY_r-W75", - "outputId": "da362eae-810a-438a-81e3-66c16698a456" - }, + "execution_count": 8, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 2.58 ms per loop\n" + "5 loops, best of 5: 1.86 ms per loop\n" ] } ], @@ -286,15 +215,86 @@ "id": "xWknFQbQ-bzV" }, "source": [ - "We can now continue with the rest of the details of `Sharding`s:" + "Now let's look at each of these pieces in more detail!\n", + "\n", + "\n", + "## `Sharding` describes how array values are laid out in memory across devices" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W6HsXauGxL6w" + }, + "source": [ + "### Sharding basics, and the `PositionalSharding` subclass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NWDyp_EjVHkg" + }, + "source": [ + "To parallelize computation across multiple devices, we first must lay out input data across multiple devices.\n", + "\n", + "In JAX, `Sharding` objects describe distributed memory layouts. They can be used with `jax.device_put` to produce a value with distributed layout.\n", + "\n", + "For example, here's a value with a single-device `Sharding`:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU 0 │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + } + ], + "source": [ + "jax.debug.visualize_array_sharding(x)" + ] + }, + { + "cell_type": "markdown", "metadata": { - "id": "f4d7c00a" + "id": "HhCjhK0zXIqX" }, + "source": [ + "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. 
All of `x` is stored on a single device, so the visualization is pretty boring!\n", + "\n", + "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, "outputs": [], "source": [ "from jax.experimental import mesh_utils\n", @@ -312,11 +312,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "K2PL4LwBX0JE", - "outputId": "1c3bbe5e-3377-49a4-a8f0-57e5f224a535" - }, + "execution_count": 12, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -362,11 +359,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "d6fd0d23", - "outputId": "2eeea24d-553d-4049-82ba-a431a08f5ac8" - }, + "execution_count": 13, + "metadata": {}, "outputs": [ { "data": { @@ -374,7 +368,7 @@ "PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -394,11 +388,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "b5445d3b", - "outputId": "16630351-cb09-42c9-92e5-fe53bc5399d5" - }, + "execution_count": 14, + "metadata": {}, "outputs": [ { "data": { @@ -413,7 +404,7 @@ " [{TPU 5}]])" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -424,11 +415,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pS7xTZeBm6Dt", - "outputId": "1d799888-7aed-4415-9812-234059321797" - }, + "execution_count": 15, + "metadata": {}, "outputs": [ { "data": { @@ -439,7 +427,7 @@ " [{TPU 4} {TPU 5}]])" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -466,11 +454,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6JhLL3i_sPth", - "outputId": "9cc7ee3e-b15c-435a-8f7d-cb127bf0ce07" - }, + "execution_count": 16, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -490,11 +475,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5FCqZfhWt88c", - "outputId": "e9abbf02-2a9c-4d34-94e0-8d73cfdc96fa" - }, + "execution_count": 17, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -525,16 +507,13 @@ "source": [ "Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories.\n", "\n", - "Different `sharding` shapes result in different distributed layouts (i.e. shardings) of the result:" + "Different `PositionalSharding` shapes result in different distributed layouts (i.e. 
shardings) of the result:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nt4IbVMkswlO", - "outputId": "ce106d8a-8ddf-4129-bf5d-1982a28f9160" - }, + "execution_count": 18, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -551,11 +530,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AyZzDpnFuIpz", - "outputId": "7cf42e32-4da6-4693-f44e-eae9ec37f203" - }, + "execution_count": 19, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -593,11 +569,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "l5t_Mg_Rux6j", - "outputId": "68ae315b-7040-4753-f803-514d88a20333" - }, + "execution_count": 20, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -614,11 +587,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Gi3sDdqAu_8W", - "outputId": "3d5b3d72-d68c-481e-854a-b3c2bbea85b6" - }, + "execution_count": 21, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -651,16 +621,13 @@ "source": [ "Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).\n", "\n", - "The `replicate` method acts similar to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:" + "The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. 
Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vDlU8hgJvson", - "outputId": "472ddce5-cdec-4434-9a59-995aa1d3cf17" - }, + "execution_count": 22, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -678,11 +645,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vHWC4womxCdf", - "outputId": "e5f9a04a-42c1-4c0d-fafa-66c57cd6f774" - }, + "execution_count": 23, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -727,11 +691,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bQCdEAHQ1q8J", - "outputId": "41dcae82-22cd-47cc-8f7c-cf0031c1e667" - }, + "execution_count": 24, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -773,10 +734,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2sDUx-VbzvVz" - }, + "execution_count": 25, + "metadata": {}, "outputs": [], "source": [ "devices = mesh_utils.create_device_mesh((4, 2))\n", @@ -784,7 +743,7 @@ "\n", "def mesh_sharding(\n", " pspec: PartitionSpec, mesh: Optional[maps.Mesh] = None,\n", - ") -> NamedSharding:\n", + " ) -> NamedSharding:\n", " if mesh is None:\n", " mesh = default_mesh\n", " return NamedSharding(mesh, pspec)" @@ -792,11 +751,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KirNGYXLzvK6", - "outputId": "777c9b33-a19c-414d-8e52-df3338bc0411" - }, + "execution_count": 26, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -830,11 +786,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JJaKU2pJ2eAC", - "outputId": "ceb3f087-3dbe-4b51-a024-b3af97f12a20" - }, + "execution_count": 27, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -861,11 +814,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nLlRM7DZ25-K", - "outputId": "c844464e-fec5-4161-ea6d-731fa0f43740" - }, + "execution_count": 28, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -884,7 +834,9 @@ } ], "source": [ - "# `None` means that `x` is replicated on the 1st dimension (counting from 0).\n", + "# This `None` means that `x` is not sharded on its second dimension,\n", + "# and since the Mesh axis name 'b' is not mentioned, shards are\n", + "# replicated across it.\n", "y = jax.device_put(x, mesh_sharding(P('a', None)))\n", "jax.debug.visualize_array_sharding(y)" ] @@ -902,11 +854,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "svq_HGHU29HV", - "outputId": "318302c1-c865-4e7c-db05-26fc3a435e78" - }, + "execution_count": 29, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -933,11 +882,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oRhRKETlX0RD", - "outputId": "966287ce-de6f-4273-fe70-3f4ab5a74bfd" - }, + "execution_count": 30, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -973,11 +919,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ldq5Ws2A3Bbl", - "outputId": "bd8d18e6-81c0-40b3-b9ae-e1315877d819" - }, + "execution_count": 31, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1032,17 +975,15 @@ "id": "JukoaRhl4tXJ" }, "source": [ - "With sharded data, the compiler can give us parallel computation. In particular, functions decorated with `jax.jit` can operate over sharded arrays without copying data onto a single device. 
Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.\n", + "With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with `jax.jit` can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.\n", "\n", "For example, the simplest computation is an elementwise one:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_NqZnEUHgZQv" - }, + "execution_count": 32, + "metadata": {}, "outputs": [], "source": [ "from jax.experimental import mesh_utils\n", @@ -1052,11 +993,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x89raigTazVJ", - "outputId": "7d59f9fe-6509-416e-ebf3-9bad49146435" - }, + "execution_count": 33, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1110,11 +1048,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D52tW3y-cx32", - "outputId": "dd895028-1a70-485b-ef46-7accb43b261b" - }, + "execution_count": 34, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1181,11 +1116,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BUcN-RqtfRml", - "outputId": "67e49f89-f76a-4e38-a231-00c76886a90b" - }, + "execution_count": 35, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1212,11 +1144,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iKrmBxJ-fhM9", - "outputId": "743ebbfe-ef93-4cd5-e030-694e7018f781" - }, + "execution_count": 36, + "metadata": {}, "outputs": [ { "data": { @@ -1224,7 +1153,7 @@ "True" ] }, - "execution_count": 34, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1236,11 +1165,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gpcGJ1PSfSAV", - "outputId": "463083fd-2e19-4bf8-fb4b-3e5f42934b91" - }, + "execution_count": 37, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1256,17 +1182,14 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1LMWZuYRfSGT", - "outputId": "5c98dc6e-9ddf-4176-8b82-fd2d16517ad2" - }, + "execution_count": 38, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 3.26 ms per loop\n" + "5 loops, best of 5: 3.25 ms per loop\n" ] } ], @@ -1285,11 +1208,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sdhFK3VGntbc", - "outputId": "662e6811-354b-4d26-be10-eb65ffdf21fa" - }, + "execution_count": 39, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1335,10 +1255,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9xmq1Jbatxwz" - }, + "execution_count": 40, + "metadata": {}, "outputs": [], "source": [ "import textwrap\n", @@ -1351,11 +1269,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Yah71IjBqyKD", - "outputId": "fd28bfaa-bf1c-4236-9c9b-fbf32ca3a8b9" - }, + "execution_count": 41, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1379,11 +1294,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HSHDAuJDqyO3", - 
"outputId": "a2b76cea-6944-4323-873b-1ba4ced9baa8" - }, + "execution_count": 42, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1425,11 +1337,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hDa3ogiwrvZx", - "outputId": "f2c186ee-4626-4352-f9af-794c042e2b89" - }, + "execution_count": 43, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1466,21 +1375,18 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_acYTsBsxpyP" - }, + "execution_count": 44, + "metadata": {}, "outputs": [], "source": [ + "# TODO(mattjj,yashkatariya): remove cell when with_sharding_constraint is in jax.lax\n", "jax.lax.with_sharding_constraint = jax.experimental.pjit.with_sharding_constraint" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KIi13NFHxz77" - }, + "execution_count": 45, + "metadata": {}, "outputs": [], "source": [ "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" @@ -1488,10 +1394,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UTRs-Zf2x8oJ" - }, + "execution_count": 46, + "metadata": {}, "outputs": [], "source": [ "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", @@ -1500,10 +1404,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fkm3YKfZwkWt" - }, + "execution_count": 47, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -1515,11 +1417,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tIglE_fayQqw", - "outputId": "ebcf2cb5-e67f-4d53-cc29-10212a88e127" - }, + "execution_count": 48, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1556,10 +1455,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DzuzKQhG2QLE" - }, + "execution_count": 49, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -1571,11 +1468,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lmDxCQ1W2TlD", - "outputId": "7664dabd-9ba2-4d53-a7da-55ead835c3f5" - }, + "execution_count": 50, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1636,9 +1530,7 @@ "id": "g7y0OJBSGoSW" }, "source": [ - "**⚠️ WARNING: The following is meant to be a demonstration of automatic sharding propagation with `jax.Array`, but it is _not a recommended practice_.**\n", - "\n", - "For neural network training, it is a good idea to constraint parameter and input sharding to be the same at every training step which can be achieved e.g. using `pjit`'s `out_axis_resources` parameter and `with_sharding_constraint`." + "**⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with `jax.Array`, but it may not reflect best practices for real examples.** For instance, real examples may require more use of `with_sharding_constraint`." 
] }, { @@ -1652,10 +1544,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sDAeJoNp_VyP" - }, + "execution_count": 51, + "metadata": {}, "outputs": [], "source": [ "import jax\n", @@ -1664,10 +1554,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "t-J6YtpA2db0" - }, + "execution_count": 52, + "metadata": {}, "outputs": [], "source": [ "def predict(params, inputs):\n", @@ -1684,10 +1572,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4USnNl6w4Y1K" - }, + "execution_count": 53, + "metadata": {}, "outputs": [], "source": [ "loss_jit = jax.jit(loss)\n", @@ -1696,10 +1582,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nfqG0N1g2dhk" - }, + "execution_count": 54, + "metadata": {}, "outputs": [], "source": [ "def init_layer(key, n_in, n_out):\n", @@ -1735,10 +1619,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uxZ4Czqyzrc5" - }, + "execution_count": 55, + "metadata": {}, "outputs": [], "source": [ "sharding = PositionalSharding(jax.devices()).reshape(8, 1)" @@ -1746,10 +1628,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q9maIR6K4T9r" - }, + "execution_count": 56, + "metadata": {}, "outputs": [], "source": [ "batch = jax.device_put(batch, sharding)\n", @@ -1758,11 +1638,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CtKIMM6ry7Ov", - "outputId": "336c1892-a3e1-4a2a-a0ee-a16dd7a24103" - }, + "execution_count": 57, + "metadata": {}, "outputs": [ { "data": { @@ -1770,7 +1647,7 @@ "Array(23.469475, dtype=float32)" ] }, - "execution_count": 55, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } @@ -1781,11 +1658,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tAM6NQkly8lw", - "outputId": "8e5caf04-3c38-426b-8e8a-823760d6688c" - }, + "execution_count": 58, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1808,17 +1682,14 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Eix05eVQy-LZ", - "outputId": "8f0d79b4-9ff3-47a9-d25f-ecc770e72195" - }, + "execution_count": 59, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 26.2 ms per loop\n" + "5 loops, best of 5: 26.3 ms per loop\n" ] } ], @@ -1828,10 +1699,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "W-19ajlSy_gF" - }, + "execution_count": 60, + "metadata": {}, "outputs": [], "source": [ "batch_single = jax.device_put(batch, jax.devices()[0])\n", @@ -1840,11 +1709,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DBHfeKyUzBD9", - "outputId": "a96ee2c8-179d-4dfa-aa72-5c0076ff6881" - }, + "execution_count": 61, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1869,10 +1735,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gw1WZyXu4owx" - }, + "execution_count": 62, + "metadata": {}, "outputs": [], "source": [ "sharding = sharding.reshape(4, 2)" @@ -1880,11 +1744,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "P0s_ibu8z0hW", - "outputId": "c9bafdc7-a811-4a38-db6b-89d55355a241" - }, + "execution_count": 63, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1919,10 +1780,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7kNJVPBjz5nq" - }, + "execution_count": 64, + "metadata": {}, "outputs": 
[], "source": [ "(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n", @@ -1944,11 +1803,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "I8ZJiiGb0HJk", - "outputId": "5c406158-6b71-45ec-bdad-2d6b39b237ca" - }, + "execution_count": 65, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1974,11 +1830,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "t2fsJ_Ow0LgK", - "outputId": "50ad8513-cb50-44d3-e0df-551ddd3c9fcb" - }, + "execution_count": 66, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2004,11 +1857,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xnNgGB7-0Nh4", - "outputId": "338d8bd4-0e0d-4a4a-99f7-e408b95ab274" - }, + "execution_count": 67, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2024,10 +1874,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ygV3-IBV0Qx3" - }, + "execution_count": 68, + "metadata": {}, "outputs": [], "source": [ "step_size = 1e-5\n", @@ -2040,11 +1888,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VWXN24Xh0Tkc", - "outputId": "13379078-3863-42f3-e5f4-368095c3a35c" - }, + "execution_count": 69, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2060,11 +1905,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Cq3TzYU70Vfd", - "outputId": "9aae2051-3e4f-4918-c221-81d21edb563a" - }, + "execution_count": 70, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2103,17 +1945,14 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hAeLBs9D0Z8T", - "outputId": "3eeacdd1-b44b-46e7-9868-dbdbca4a78e2" - }, + "execution_count": 71, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10 loops, best of 10: 30.6 ms per loop\n" + "10 loops, best of 10: 30.5 ms per loop\n" ] } ], @@ -2156,10 +1995,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "quUwyGoiHub2" - }, + "execution_count": 72, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -2183,11 +2020,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nxPHw0loLgFd", - "outputId": "59514362-61c5-4181-c7e9-ad7777218779" - }, + "execution_count": 73, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2214,11 +2048,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uf6hTkjvKifH", - "outputId": "3af4c656-0753-4dd5-e4be-0097bba43256" - }, + "execution_count": 74, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2244,11 +2075,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "G87r_Aq6Ts_F", - "outputId": "5c8cd729-5f2a-450d-ee40-fb788f91707a" - }, + "execution_count": 75, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2275,11 +2103,8 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8RplTPyRSTbW", - "outputId": "74e2ef51-c5c0-4c25-cc90-2f465961bdae" - }, + "execution_count": 76, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2301,16 +2126,13 @@ "id": "kaK--hPmSPpV" }, "source": [ - "One caveat, however, is that _the random values produced may be different than before_, even though they were generated by the same random key:" + "One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:" ] }, { "cell_type": 
"code", - "execution_count": null, - "metadata": { - "id": "f_EjYjOpSO18", - "outputId": "b50278f2-927d-4aea-fb04-cdd5c2fc1a66" - }, + "execution_count": 77, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -2353,6 +2175,10 @@ ], "metadata": { "colab": { + "last_runtime": { + "build_target": "//learning/multipod/pax/tools:colab_notebook", + "kind": "private" + }, "provenance": [], "toc_visible": true }, diff --git a/docs/notebooks/Parallelism_with_JAX.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md similarity index 77% rename from docs/notebooks/Parallelism_with_JAX.md rename to docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index a34e09c1a..34d13613a 100644 --- a/docs/notebooks/Parallelism_with_JAX.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -13,19 +13,17 @@ kernelspec: +++ {"id": "PxHrg4Cjuapm"} -# Parallelism with JAX +# Distributed arrays and automatic parallelization +++ {"id": "pFtQjv4SzHRj"} **This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.0 and newer.** -See {ref}`jax-array-migration` guide for migrating existing pre-v0.4.0 codebases to `jax.Array`. +See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.0 codebases to `jax.Array`. -**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time.** +**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Cloud TPU.** ```{code-cell} -:id: 41dde63b - import os import functools @@ -41,10 +39,73 @@ jax.config.update('jax_array', True) +++ {"id": "eyHMwyEfQJcz"} -⚠️ WARNING: The notebook requires multiple devices in order to run correctly. +⚠️ WARNING: notebook requires 8 devices to run. + +```{code-cell} +if len(jax.local_devices()) < 8: + raise Exception("Notebook requires 8 devices to run") +``` +++ {"id": "3f37ca93"} +## Intro and a quick example + +By reading this tutorial notebook, you'll learn about `jax.Array`, a unified +datatype for representing arrays, even with physical storage spanning multiple +devices. You'll also learn about how using `jax.Array`s together with `jax.jit` +can provide automatic compiler-based parallelization. + +Before we think step by step, here's a quick example. 
+First, we'll create a `jax.Array` sharded across multiple devices: + +```{code-cell} +from jax.experimental import mesh_utils +from jax.sharding import PositionalSharding +``` + +```{code-cell} +# Create a Sharding object to distribute a value across devices: +sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +``` + +```{code-cell} +# Create an array of random values: +x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +# and use jax.device_put to distribute it across devices: +y = jax.device_put(x, sharding.reshape(4, 2)) +jax.debug.visualize_array_sharding(y) +``` + ++++ {"id": "jZ0ZY9Um9Jg4"} + +Next, we'll apply a computation to it and visualize how the result values are +stored across multiple devices too: + +```{code-cell} +z = jnp.sin(y) +jax.debug.visualize_array_sharding(z) +``` + ++++ {"id": "5qccVQoE9tEi"} + +The evaluation of the `jnp.sin` application was automatically parallelized +across the devices on which the input values (and output values) are stored: + +```{code-cell} +# `x` is present on a single device +%timeit -n 5 -r 5 jnp.sin(x).block_until_ready() +``` + +```{code-cell} +# `y` is sharded across 8 devices. +%timeit -n 5 -r 5 jnp.sin(y).block_until_ready() +``` + ++++ {"id": "xWknFQbQ-bzV"} + +Now let's look at each of these pieces in more detail! + + ## `Sharding` describes how array values are laid out in memory across devices +++ {"id": "W6HsXauGxL6w"} @@ -53,19 +114,21 @@ jax.config.update('jax_array', True) +++ {"id": "NWDyp_EjVHkg"} -To parallelize computation across multiple devices, we first have to describe how the data that computation acts on can be distributed across multiple devices. That means describing distributed layouts conveniently. +To parallelize computation across multiple devices, we first must lay out input data across multiple devices. In JAX, `Sharding` objects describe distributed memory layouts. They can be used with `jax.device_put` to produce a value with distributed layout. For example, here's a value with a single-device `Sharding`: ```{code-cell} -:id: s5jXIod7VcWW - import jax x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) ``` +```{code-cell} +jax.debug.visualize_array_sharding(x) +``` + +++ {"id": "HhCjhK0zXIqX"} Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring! @@ -73,74 +136,6 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order: ```{code-cell} -:id: 8fc925d2 -:outputId: 40f044de-0a39-46cb-9e61-2277648a9c0e - -jax.debug.visualize_array_sharding(x) -``` - -+++ {"id": "xKC-WWgc8QGo"} - -A quick example of what `jax.Array` can do before we dive into more details: - -```{code-cell} -:id: kiZ59Mho5lzk - -from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding - -# Let's create a Sharding object which we will use to distribute -# a value across devices. 
-sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) -``` - -```{code-cell} -:id: t_Mw8cxU7YSK -:outputId: b44f6bea-c195-4b9a-d872-815112e8ef2e - -y = jax.device_put(x, sharding.reshape(4, 2)) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "jZ0ZY9Um9Jg4"} - -We distributed `x` which was on a single device before to 8 devices. - -```{code-cell} -:id: WUzDakoG7a2l -:outputId: b575be88-cbb1-4081-fa98-1ba5c6421921 - -z = jnp.sin(y) -jax.debug.visualize_array_sharding(z) -``` - -+++ {"id": "5qccVQoE9tEi"} - -After doing an `jnp.sin` operation on a distributed `jax.Array`, the sharding on the `output` was preserved. Also the operation itself happened on multiple devices. To test that, we can do a timing experiment: - -```{code-cell} -:id: RdfmbYxr-2I_ -:outputId: 50e659d4-9ca0-499e-bd49-66fe59384675 - -# `x` is present on a single device -%timeit -n 5 -r 5 jnp.sin(x).block_until_ready() -``` - -```{code-cell} -:id: cNLJQY_r-W75 -:outputId: da362eae-810a-438a-81e3-66c16698a456 - -# `y` is sharded across 8 devices. -%timeit -n 5 -r 5 jnp.sin(y).block_until_ready() -``` - -+++ {"id": "xWknFQbQ-bzV"} - -We can now continue with the rest of the details of `Sharding`s: - -```{code-cell} -:id: f4d7c00a - from jax.experimental import mesh_utils devices = mesh_utils.create_device_mesh((8,)) ``` @@ -150,9 +145,6 @@ devices = mesh_utils.create_device_mesh((8,)) Then, we create a `PositionalSharding` and use it with `device_put`: ```{code-cell} -:id: K2PL4LwBX0JE -:outputId: 1c3bbe5e-3377-49a4-a8f0-57e5f224a535 - from jax.sharding import PositionalSharding sharding = PositionalSharding(devices) @@ -166,9 +158,6 @@ jax.debug.visualize_array_sharding(x) Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements: ```{code-cell} -:id: d6fd0d23 -:outputId: 2eeea24d-553d-4049-82ba-a431a08f5ac8 - sharding ``` @@ -177,16 +166,10 @@ sharding By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it: ```{code-cell} -:id: b5445d3b -:outputId: 16630351-cb09-42c9-92e5-fe53bc5399d5 - sharding.reshape(8, 1) ``` ```{code-cell} -:id: pS7xTZeBm6Dt -:outputId: 1d799888-7aed-4415-9812-234059321797 - sharding.reshape(4, 2) ``` @@ -202,17 +185,11 @@ def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool: For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`: ```{code-cell} -:id: 6JhLL3i_sPth -:outputId: 9cc7ee3e-b15c-435a-8f7d-cb127bf0ce07 - sharding = sharding.reshape(4, 2) print(sharding) ``` ```{code-cell} -:id: 5FCqZfhWt88c -:outputId: e9abbf02-2a9c-4d34-94e0-8d73cfdc96fa - y = jax.device_put(x, sharding) jax.debug.visualize_array_sharding(y) ``` @@ -221,20 +198,14 @@ jax.debug.visualize_array_sharding(y) Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories. -Different `sharding` shapes result in different distributed layouts (i.e. shardings) of the result: +Different `PositionalSharding` shapes result in different distributed layouts (i.e. 
shardings) of the result: ```{code-cell} -:id: nt4IbVMkswlO -:outputId: ce106d8a-8ddf-4129-bf5d-1982a28f9160 - sharding = sharding.reshape(1, 8) print(sharding) ``` ```{code-cell} -:id: AyZzDpnFuIpz -:outputId: 7cf42e32-4da6-4693-f44e-eae9ec37f203 - y = jax.device_put(x, sharding) jax.debug.visualize_array_sharding(y) ``` @@ -246,17 +217,11 @@ In some cases, we don't just want to store each slice of `x` in a single device' With `PositionalSharding`, we can express replication by calling the reducer method `replicate`: ```{code-cell} -:id: l5t_Mg_Rux6j -:outputId: 68ae315b-7040-4753-f803-514d88a20333 - sharding = sharding.reshape(4, 2) print(sharding.replicate(axis=0, keepdims=True)) ``` ```{code-cell} -:id: Gi3sDdqAu_8W -:outputId: 3d5b3d72-d68c-481e-854a-b3c2bbea85b6 - y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True)) jax.debug.visualize_array_sharding(y) ``` @@ -265,20 +230,14 @@ jax.debug.visualize_array_sharding(y) Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories). -The `replicate` method acts similar to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed: +The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. 
Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed: ```{code-cell} -:id: vDlU8hgJvson -:outputId: 472ddce5-cdec-4434-9a59-995aa1d3cf17 - print(sharding.replicate(0).shape) print(sharding.replicate(1).shape) ``` ```{code-cell} -:id: vHWC4womxCdf -:outputId: e5f9a04a-42c1-4c0d-fafa-66c57cd6f774 - y = jax.device_put(x, sharding.replicate(1)) jax.debug.visualize_array_sharding(y) ``` @@ -294,9 +253,6 @@ So far we've worked with `PositionalSharding`, but there are alternative ways to Another convenient way to express sharding is with the `NamedSharding`: ```{code-cell} -:id: bQCdEAHQ1q8J -:outputId: 41dcae82-22cd-47cc-8f7c-cf0031c1e667 - from jax.experimental import maps from jax.experimental import PartitionSpec from jax.experimental import mesh_utils @@ -314,23 +270,18 @@ jax.debug.visualize_array_sharding(y) We can define a helper function to make things simpler: ```{code-cell} -:id: 2sDUx-VbzvVz - devices = mesh_utils.create_device_mesh((4, 2)) default_mesh = maps.Mesh(devices, axis_names=('a', 'b')) def mesh_sharding( pspec: PartitionSpec, mesh: Optional[maps.Mesh] = None, -) -> NamedSharding: + ) -> NamedSharding: if mesh is None: mesh = default_mesh return NamedSharding(mesh, pspec) ``` ```{code-cell} -:id: KirNGYXLzvK6 -:outputId: 777c9b33-a19c-414d-8e52-df3338bc0411 - y = jax.device_put(x, mesh_sharding(P('a', 'b'))) jax.debug.visualize_array_sharding(y) ``` @@ -340,18 +291,14 @@ jax.debug.visualize_array_sharding(y) Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices: ```{code-cell} -:id: JJaKU2pJ2eAC -:outputId: ceb3f087-3dbe-4b51-a024-b3af97f12a20 - y = jax.device_put(x, mesh_sharding(P('b', 'a'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: nLlRM7DZ25-K -:outputId: c844464e-fec5-4161-ea6d-731fa0f43740 - -# `None` means that `x` is replicated on the 1st dimension (counting from 0). +# This `None` means that `x` is not sharded on its second dimension, +# and since the Mesh axis name 'b' is not mentioned, shards are +# replicated across it. y = jax.device_put(x, mesh_sharding(P('a', None))) jax.debug.visualize_array_sharding(y) ``` @@ -363,17 +310,11 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`: ```{code-cell} -:id: svq_HGHU29HV -:outputId: 318302c1-c865-4e7c-db05-26fc3a435e78 - y = jax.device_put(x, mesh_sharding(P(None, 'b'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: oRhRKETlX0RD -:outputId: 966287ce-de6f-4273-fe70-3f4ab5a74bfd - y = jax.device_put(x, mesh_sharding(P(None, 'a'))) jax.debug.visualize_array_sharding(y) ``` @@ -383,9 +324,6 @@ jax.debug.visualize_array_sharding(y) For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes: ```{code-cell} -:id: ldq5Ws2A3Bbl -:outputId: bd8d18e6-81c0-40b3-b9ae-e1315877d819 - y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None))) jax.debug.visualize_array_sharding(y) ``` @@ -400,22 +338,17 @@ Using `NamedSharding` makes it easy to define a device mesh once and give its ax +++ {"id": "JukoaRhl4tXJ"} -With sharded data, the compiler can give us parallel computation. In particular, functions decorated with `jax.jit` can operate over sharded arrays without copying data onto a single device. 
Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary. +With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with `jax.jit` can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary. For example, the simplest computation is an elementwise one: ```{code-cell} -:id: _NqZnEUHgZQv - from jax.experimental import mesh_utils from jax.sharding import PositionalSharding sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) ``` ```{code-cell} -:id: x89raigTazVJ -:outputId: 7d59f9fe-6509-416e-ebf3-9bad49146435 - x = jax.device_put(x, sharding.reshape(4, 2)) print('input sharding:') jax.debug.visualize_array_sharding(x) @@ -434,9 +367,6 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs: ```{code-cell} -:id: D52tW3y-cx32 -:outputId: dd895028-1a70-485b-ef46-7accb43b261b - y = jax.device_put(x, sharding.reshape(4, 2).replicate(1)) z = jax.device_put(x, sharding.reshape(4, 2).replicate(0)) print('lhs sharding:') @@ -456,32 +386,20 @@ Here the compiler chose the output sharding so that it could maximally paralleli How can we be sure it's actually running in parallel? We can do a simple timing experiment: ```{code-cell} -:id: BUcN-RqtfRml -:outputId: 67e49f89-f76a-4e38-a231-00c76886a90b - x_single = jax.device_put(x, jax.devices()[0]) jax.debug.visualize_array_sharding(x_single) ``` ```{code-cell} -:id: iKrmBxJ-fhM9 -:outputId: 743ebbfe-ef93-4cd5-e030-694e7018f781 - np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z)) ``` ```{code-cell} -:id: gpcGJ1PSfSAV -:outputId: 463083fd-2e19-4bf8-fb4b-3e5f42934b91 - %timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready() ``` ```{code-cell} -:id: 1LMWZuYRfSGT -:outputId: 5c98dc6e-9ddf-4176-8b82-fd2d16517ad2 - %timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready() ``` @@ -490,9 +408,6 @@ np.allclose(jnp.dot(x_single, x_single), Even copying a sharded `Array` produces a result with the sharding of the input: ```{code-cell} -:id: sdhFK3VGntbc -:outputId: 662e6811-354b-4d26-be10-eb65ffdf21fa - w_copy = jnp.copy(w) jax.debug.visualize_array_sharding(w_copy) ``` @@ -509,8 +424,6 @@ But what if two arguments to a computation are explicitly placed on different se In these ambiguous cases, an error is raised: ```{code-cell} -:id: 9xmq1Jbatxwz - import textwrap from termcolor import colored @@ -520,9 +433,6 @@ def print_exception(e): ``` ```{code-cell} -:id: Yah71IjBqyKD -:outputId: fd28bfaa-bf1c-4236-9c9b-fbf32ca3a8b9 - sharding1 = PositionalSharding(jax.devices()[:4]) sharding2 = PositionalSharding(jax.devices()[4:]) @@ -533,9 +443,6 @@ except ValueError as e: print_exception(e) ``` ```{code-cell} -:id: HSHDAuJDqyO3 -:outputId: a2b76cea-6944-4323-873b-1ba4ced9baa8 - devices = jax.devices() permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]] @@ -558,9 +465,6 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic For example, the output of `jnp.zeros`, `jnp.arange`, and `jnp.array` 
are uncommitted: ```{code-cell} -:id: hDa3ogiwrvZx -:outputId: f2c186ee-4626-4352-f9af-794c042e2b89 - y = jax.device_put(x, sharding1.reshape(4, 2)) y + jnp.ones_like(y) y + jnp.arange(y.size).reshape(y.shape) @@ -576,27 +480,20 @@ print('no error!') While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using `jax.lax.with_sharding_constraint`. Using `jax.lax.with_sharding_constraint` is much like `jax.device_put`, except we use it inside staged-out (i.e. `jit`-decorated) functions: ```{code-cell} -:id: _acYTsBsxpyP - +# TODO(mattjj,yashkatariya): remove cell when with_sharding_constraint is in jax.lax jax.lax.with_sharding_constraint = jax.experimental.pjit.with_sharding_constraint ``` ```{code-cell} -:id: KIi13NFHxz77 - sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) ``` ```{code-cell} -:id: UTRs-Zf2x8oJ - x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) x = jax.device_put(x, sharding.reshape(4, 2)) ``` ```{code-cell} -:id: fkm3YKfZwkWt - @jax.jit def f(x): x = x + 1 @@ -605,17 +502,12 @@ def f(x): ``` ```{code-cell} -:id: tIglE_fayQqw -:outputId: ebcf2cb5-e67f-4d53-cc29-10212a88e127 - jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: DzuzKQhG2QLE - @jax.jit def f(x): x = x + 1 @@ -624,9 +516,6 @@ def f(x): ``` ```{code-cell} -:id: lmDxCQ1W2TlD -:outputId: 7664dabd-9ba2-4d53-a7da-55ead835c3f5 - jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) @@ -644,24 +533,18 @@ It's often a good practice to annotate the outputs of computations, for example +++ {"id": "g7y0OJBSGoSW"} -**⚠️ WARNING: The following is meant to be a demonstration of automatic sharding propagation with `jax.Array`, but it is _not a recommended practice_.** - -For neural network training, it is a good idea to constraint parameter and input sharding to be the same at every training step which can be achieved e.g. using `pjit`'s `out_axis_resources` parameter and `with_sharding_constraint`. +**⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with `jax.Array`, but it may not reflect best practices for real examples.** For instance, real examples may require more use of `with_sharding_constraint`. +++ {"id": "3ii_UPkG3gzP"} We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. 
Here are some simple examples, based on this basic neural network: ```{code-cell} -:id: sDAeJoNp_VyP - import jax import jax.numpy as jnp ``` ```{code-cell} -:id: t-J6YtpA2db0 - def predict(params, inputs): for W, b in params: outputs = jnp.dot(inputs, W) + b @@ -675,15 +558,11 @@ def loss(params, batch): ``` ```{code-cell} -:id: 4USnNl6w4Y1K - loss_jit = jax.jit(loss) gradfun = jax.jit(jax.grad(loss)) ``` ```{code-cell} -:id: nfqG0N1g2dhk - def init_layer(key, n_in, n_out): k1, k2 = jax.random.split(key) W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) @@ -711,29 +590,19 @@ params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size) ### 8-way batch data parallelism ```{code-cell} -:id: uxZ4Czqyzrc5 - sharding = PositionalSharding(jax.devices()).reshape(8, 1) ``` ```{code-cell} -:id: q9maIR6K4T9r - batch = jax.device_put(batch, sharding) params = jax.device_put(params, sharding.replicate()) ``` ```{code-cell} -:id: CtKIMM6ry7Ov -:outputId: 336c1892-a3e1-4a2a-a0ee-a16dd7a24103 - loss_jit(params, batch) ``` ```{code-cell} -:id: tAM6NQkly8lw -:outputId: 8e5caf04-3c38-426b-8e8a-823760d6688c - step_size = 1e-5 for _ in range(30): @@ -745,23 +614,15 @@ print(loss_jit(params, batch)) ``` ```{code-cell} -:id: Eix05eVQy-LZ -:outputId: 8f0d79b4-9ff3-47a9-d25f-ecc770e72195 - %timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready() ``` ```{code-cell} -:id: W-19ajlSy_gF - batch_single = jax.device_put(batch, jax.devices()[0]) params_single = jax.device_put(params, jax.devices()[0]) ``` ```{code-cell} -:id: DBHfeKyUzBD9 -:outputId: a96ee2c8-179d-4dfa-aa72-5c0076ff6881 - %timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready() ``` @@ -770,23 +631,16 @@ params_single = jax.device_put(params, jax.devices()[0]) ### 4-way batch data parallelism and 2-way model tensor parallelism ```{code-cell} -:id: gw1WZyXu4owx - sharding = sharding.reshape(4, 2) ``` ```{code-cell} -:id: P0s_ibu8z0hW -:outputId: c9bafdc7-a811-4a38-db6b-89d55355a241 - batch = jax.device_put(batch, sharding.replicate(1)) jax.debug.visualize_array_sharding(batch[0]) jax.debug.visualize_array_sharding(batch[1]) ``` ```{code-cell} -:id: 7kNJVPBjz5nq - (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params W1 = jax.device_put(W1, sharding.replicate()) @@ -805,29 +659,18 @@ params = (W1, b1), (W2, b2), (W3, b3), (W4, b4) ``` ```{code-cell} -:id: I8ZJiiGb0HJk -:outputId: 5c406158-6b71-45ec-bdad-2d6b39b237ca - jax.debug.visualize_array_sharding(W2) ``` ```{code-cell} -:id: t2fsJ_Ow0LgK -:outputId: 50ad8513-cb50-44d3-e0df-551ddd3c9fcb - jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: xnNgGB7-0Nh4 -:outputId: 338d8bd4-0e0d-4a4a-99f7-e408b95ab274 - print(loss_jit(params, batch)) ``` ```{code-cell} -:id: ygV3-IBV0Qx3 - step_size = 1e-5 for _ in range(30): @@ -837,25 +680,16 @@ for _ in range(30): ``` ```{code-cell} -:id: VWXN24Xh0Tkc -:outputId: 13379078-3863-42f3-e5f4-368095c3a35c - print(loss_jit(params, batch)) ``` ```{code-cell} -:id: Cq3TzYU70Vfd -:outputId: 9aae2051-3e4f-4918-c221-81d21edb563a - (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params jax.debug.visualize_array_sharding(W2) jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: hAeLBs9D0Z8T -:outputId: 3eeacdd1-b44b-46e7-9868-dbdbca4a78e2 - %timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() ``` @@ -878,8 +712,6 @@ However, the existing stable RNG implementation is not automatically partitionab Consider the following example, where a function draws random uniform numbers and adds them to the input, 
elementwise: ```{code-cell} -:id: quUwyGoiHub2 - @jax.jit def f(key, x): numbers = jax.random.uniform(key, x.shape) @@ -895,9 +727,6 @@ x = jax.device_put(jnp.arange(24), x_sharding) On a partitioned input, the function `f` produces output that is also partitioned: ```{code-cell} -:id: nxPHw0loLgFd -:outputId: 59514362-61c5-4181-c7e9-ad7777218779 - jax.debug.visualize_array_sharding(f(key, x)) ``` @@ -906,9 +735,6 @@ jax.debug.visualize_array_sharding(f(key, x)) But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication: ```{code-cell} -:id: uf6hTkjvKifH -:outputId: 3af4c656-0753-4dd5-e4be-0097bba43256 - f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) ``` @@ -918,9 +744,6 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation: ```{code-cell} -:id: G87r_Aq6Ts_F -:outputId: 5c8cd729-5f2a-450d-ee40-fb788f91707a - jax.config.update('jax_threefry_partitionable', True) f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) @@ -931,20 +754,14 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) The output is still partitioned: ```{code-cell} -:id: 8RplTPyRSTbW -:outputId: 74e2ef51-c5c0-4c25-cc90-2f465961bdae - jax.debug.visualize_array_sharding(f(key, x)) ``` +++ {"id": "kaK--hPmSPpV"} -One caveat, however, is that _the random values produced may be different than before_, even though they were generated by the same random key: +One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key: ```{code-cell} -:id: f_EjYjOpSO18 -:outputId: b50278f2-927d-4aea-fb04-cdd5c2fc1a66 - jax.config.update('jax_threefry_partitionable', False) print('Stable:') print(f(key, x))