diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 029f72e83..eddca16d5 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -18,6 +18,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu multi_process notebooks/Distributed_arrays_and_automatic_parallelization + notebooks/shard_map notebooks/xmap_tutorial .. toctree:: diff --git a/docs/conf.py b/docs/conf.py index 917382d67..5becd88de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -209,6 +209,7 @@ nb_execution_excludepatterns = [ 'notebooks/xmap_tutorial.*', 'notebooks/Distributed_arrays_and_automatic_parallelization.*', 'notebooks/autodiff_remat.*', + 'notebooks/shard_map.*', # Requires accelerators 'pallas/quickstart.*', ] diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb new file mode 100644 index 000000000..b4ce840bb --- /dev/null +++ b/docs/notebooks/shard_map.ipynb @@ -0,0 +1,1964 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "97c57a94", + "metadata": {}, + "source": [ + "# Intro\n", + "\n", + "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", + "\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "\n", + "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation.
We'll also give some basic examples of neural network parallelization strategies.\n", + "\n", + "## So, let's see a `shard_map`!\n", + "\n", + "Without further ado, here's a toy example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3f562ec", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "from jax.sharding import Mesh, PartitionSpec as P\n", + "from jax.experimental import mesh_utils\n", + "from jax.experimental.shard_map import shard_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64a910c1", + "metadata": {}, + "outputs": [], + "source": [ + "devices = mesh_utils.create_device_mesh((4, 2))\n", + "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "\n", + "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", + "b = jnp.arange(16 * 4.).reshape(16, 4)\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", + " out_specs=P('x', None))\n", + "def matmul_basic(a_block, b_block):\n", + " # a_block: f32[2, 8]\n", + " # b_block: f32[8, 4]\n", + " c_partialsum = jnp.dot(a_block, b_block)\n", + " c_block = jax.lax.psum(c_partialsum, 'y')\n", + " # c_block: f32[2, 4]\n", + " return c_block\n", + "\n", + "c = matmul_basic(a, b) # c: f32[8, 4]" + ] + }, + { + "cell_type": "markdown", + "id": "85c92753", + "metadata": {}, + "source": [ + "This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09c382bb", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.tree_util import tree_map, tree_all\n", + "\n", + "def allclose(a, b):\n", + " return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))\n", + "\n", + "allclose(c, jnp.dot(a, b))" + ] + }, + { + "cell_type": "markdown", + "id": "7618d7f8", + "metadata": {}, + "source": [ + "The result is sharded along its rows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "932af90f", + "metadata": {}, + "outputs": [], + "source": [ + "jax.debug.visualize_array_sharding(c)" + ] + }, + { + "cell_type": "markdown", + "id": "a79329b0", + "metadata": {}, + "source": [ + "At a high level, `shard_map` is kind of like `vmap` or `pmap`, in that we're\n", + "mapping a function over pieces of array data, but notice that\n", + "* `shard_map` slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas `vmap` would reduce the rank by mapping away an axis;\n", + "* the `mesh` argument lets us control precise device placement of computation and results;\n", + "* we're mapping over multiple data axes at once, and setting up multiple axis names for collectives (both `'x'` and `'y'` here);\n", + "* since we're not using `jax.jit` yet, everything is eagerly evaluated, and we can even `print` intermediate values for debugging.\n", + "\n", + "The above code is performing the same computation as this `jax.jit` automatic parallelization code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9cf3cdd", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.sharding import NamedSharding\n", + "\n", + "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n", + "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n", + "\n", + "@jax.jit\n", + "def matmul_reference(a, b):\n",
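+ "  # a and b already carry NamedShardings from the device_put calls above, so\n", + "  # jit partitions this dot automatically; the with_sharding_constraint below\n", + "  # pins the output sharding to P('x', None), matching matmul_basic's out_specs\n",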
+ " c = jnp.dot(a, b)\n", + " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n", + "\n", + "c_ref = matmul_reference(a, b)\n", + "allclose(c_ref, jnp.dot(a, b))" + ] + }, + { + "cell_type": "markdown", + "id": "ab826433", + "metadata": {}, + "source": [ + "We can think of `shard_map` as performing a `device_put` or\n", + "`with_sharding_constraint` on its inputs according to its `mesh` and `in_specs`\n", + "arguments, so the blocks over which `matmul_basic` operates are the same as in\n", + "`matmul_reference`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8497b5f8", + "metadata": {}, + "outputs": [], + "source": [ + "print('a blocks:'); jax.debug.visualize_array_sharding(a)\n", + "print('b blocks:'); jax.debug.visualize_array_sharding(b)\n", + "print('c blocks:'); jax.debug.visualize_array_sharding(c)" + ] + }, + { + "cell_type": "markdown", + "id": "532fe5f6", + "metadata": {}, + "source": [ + "## Slow down, start with the basics!\n", + "\n", + "### Rank-reducing vs rank-preserving maps\n", + "\n", + "We can think of `vmap` and `pmap` as unstacking each array input along an axis\n", + "(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to\n", + "each piece, and stacking the results back together, at least when collectives\n", + "aren't involved:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb8e1883", + "metadata": {}, + "outputs": [], + "source": [ + "def check_vmap(f, xs):\n", + " ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)\n", + " expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics\n", + " print(allclose(ans, expected))\n", + "\n", + "check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))" + ] + }, + { + "cell_type": "markdown", + "id": "9d55b900", + "metadata": {}, + "source": [ + "For example, if `xs` had shape `f32[8,5]` then each `x` would have shape\n", + "`f32[5]`, and if each `f(x)` had shape `f32[3,7]` then the final stacked result\n", + "`vmap(f)(xs)` would have shape `f32[8,3,7]`. That is, each application of the\n", + "body function `f` takes as argument inputs with one fewer axis than the\n", + "corresponding argument to `vmap(f)`. We can say these are _rank-reducing maps_\n", + "with unstacking/stacking of inputs/outputs.\n", + "\n", + "The number of logical applications of `f`, or _instances_ of `f`, is determined\n", + "by the size of the input axis being mapped over: for example, if we map over an\n", + "input axis of size 8, semantically we get 8 logical applications of the\n", + "function.\n", + "\n", + "In contrast, `shard_map` does not have this rank-reducing behavior. 
Instead, we\n", + "can think of it as slicing (or \"unconcatenating\") along input axes into blocks,\n", + "applying the body function, and concatenating the results back together (again\n", + "when collectives aren't involved):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30e89f3f", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "devices = np.array(jax.devices()[:4])\n", + "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", + "\n", + "def check_shmap(f, y):\n", + " ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", + " expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])\n", + " print(allclose(ans, expected))\n", + "\n", + "check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))" + ] + }, + { + "cell_type": "markdown", + "id": "985ff202", + "metadata": {}, + "source": [ + "Recall that jnp.split slices its input into equally-sized blocks with the same\n", + "rank, so that if in the above example `y` had shape `f32[8,5]` then each\n", + "`y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape\n", + "`f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have\n", + "shape `f32[12,7]`. So `shard_map` maps over _shards_, or blocks, of its inputs.\n", + "We can say it's a _rank-preserving map_ with unconcatenating/concatenating of\n", + "its inputs/outputs.\n", + "\n", + "The number of logical applications of `f` is determined by the mesh size, not\n", + "by any input axis size: for example, if we have a mesh of total size 4 (i.e.\n", + "over 4 devices) then semantically we get 4 logical applications of the\n", + "function, corresponding to the 4 devices physically computing them.\n", + "\n", + "### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n", + "\n", + "Each of the `in_specs` identifies some of the corresponding input array's axes\n", + "with mesh axes by name using `PartitionSpec`s, representing how to split (or\n", + "unconcatenate) that input into the blocks to which the body function is\n", + "applied. That identification determines the shard sizes; when an input axis is\n", + "identified with a mesh axis, the input is split (unconcatenated) along that\n", + "logical axis into a number of pieces equal to the corresponding mesh axis size.\n", + "(It's an error if the corresponding mesh axis size does not evenly divide the\n", + "input array axis size.) If an input's pspec does not mention a mesh axis name,\n", + "then there's no splitting over that mesh axis. 
For example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08555009", + "metadata": {}, + "outputs": [], + "source": [ + "devices = mesh_utils.create_device_mesh((4, 2))\n", + "mesh = Mesh(devices, ('i', 'j'))\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", + "def f1(x_block):\n", + " print(x_block.shape) # prints (3, 12)\n", + " return x_block\n", + "\n", + "x1 = jnp.arange(12 * 12).reshape(12, 12)\n", + "y = f1(x1)" + ] + }, + { + "cell_type": "markdown", + "id": "f191681b", + "metadata": {}, + "source": [ + "Here, because the input pspec did not mention the mesh axis name `'j'`, no\n", + "input array axis is split over that mesh axis; similarly, because the second\n", + "axis of the input array is not identified with (and hence split over) any mesh\n", + "axis, application of `f1` gets a full view of the input along that axis.\n", + "\n", + "When a mesh axis is not mentioned in an input pspec, we can always rewrite to a\n", + "less efficient program where all mesh axes are mentioned but the caller\n", + "performs a `jnp.tile`, for example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2515872d", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", + "def f2(x_block):\n", + " print(x_block.shape)\n", + " return x_block\n", + "\n", + "x = jnp.arange(12 * 12).reshape(12, 12)\n", + "x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)\n", + "y = f2(x_) # prints (3,12), and f1(x) == f2(x_)" + ] + }, + { + "cell_type": "markdown", + "id": "8be0595e", + "metadata": {}, + "source": [ + "In other words, because each input pspec can mention each mesh axis name zero\n", + "or one times, rather than having to mention each name exactly once, we can say\n", + "that in addition to the `jnp.split` built into its input, `shard_map` also has\n", + "a `jnp.tile` built into its input, at least logically (though the tiling may\n", + "not need to be carried out physically, depending on the arguments' physical\n", + "sharding layout). The tiling to use is not unique; we could also have tiled\n", + "along the first axis, and used the pspec `P(('j', 'i'), None)`.\n", + "\n", + "Physical data movement is possible on inputs, as each device needs to have a\n", + "copy of the appropriate data.\n", + "\n", + "### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`\n", + "\n", + "Analogously to the input side, each of the `out_specs` identifies some of the\n", + "corresponding output array's axes with mesh axes by name, representing how the\n", + "output blocks (one for each application of the body function, or equivalently\n", + "one for each physical device) should be assembled back together to form the\n", + "final output value. For example, in both the `f1` and `f2` examples above the\n", + "`out_specs` indicate we should form the final output by concatenating together\n", + "the block results along both axes, resulting in both cases in an array `y` of\n", + "shape `(12, 24)`.\n",
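+ "\n", + "As a quick check on this block arithmetic (a small sketch reusing `f1`, `x1`, `f2`, and `x_` from the cells above): each block result has shape `(3, 12)`, and `out_specs=P('i', 'j')` concatenates 4 such blocks along the first axis and 2 along the second:\n", + "\n", + "```python\n", + "assert f1(x1).shape == (4 * 3, 2 * 12)  # == (12, 24)\n", + "assert f2(x_).shape == (4 * 3, 2 * 12)  # == (12, 24)\n", + "```\n", + "\n", + "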
(It's an error if an output shape of the body function, i.e.\n", + "an output block shape, has a rank too small for the concatenation described by\n", + "the corresponding output pspec.)\n", + "\n", + "When a mesh axis name is not mentioned in an output pspec, it represents an\n", + "un-tiling: when the user writes an output pspec which does not mention one of\n", + "the mesh axis names, they promise that the output blocks are equal along that\n", + "mesh axis, and so only one block along that axis is used in the output (rather\n", + "than concatenating all the blocks together along that mesh axis). For example,\n", + "using the same mesh as above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42fa392a", + "metadata": {}, + "outputs": [], + "source": [ + "x = jnp.array([[3.]])\n", + "\n", + "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", + "print(z) # prints the same as jnp.tile(x, (4, 2))\n", + "\n", + "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", + "print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))\n", + "\n", + "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", + "print(z) # prints the same as jnp.tile(x, (1, 1)), or just x" + ] + }, + { + "cell_type": "markdown", + "id": "9c25db8c", + "metadata": {}, + "source": [ + "The body function closing over an array value is equivalent to passing it as an\n", + "argument with a corresponding input pspec of `P(None, None)`. As another example,\n", + "following more closely the other examples above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3c0a0e3", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", + "def f3(x_block):\n", + " return jax.lax.psum(x_block, 'j')\n", + "\n", + "x = jnp.arange(12 * 12).reshape(12, 12)\n", + "y3 = f3(x)\n", + "print(y3.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "0adc8960", + "metadata": {}, + "source": [ + "The result has a second axis size of 6, half the size of the input's second\n", + "axis. In this case, the un-tile expressed by not mentioning the mesh axis name\n", + "`'j'` in the output pspec was safe because of the collective `psum`, which\n", + "ensures each output block is equal along the corresponding mesh axis.
Here are\n", + "two more examples where we vary which mesh axes are mentioned in the output\n", + "pspec:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b65636dd", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "def f4(x_block):\n", + " return jax.lax.psum(x_block, 'i')\n", + "\n", + "x = jnp.arange(12 * 12).reshape(12, 12)\n", + "y4 = f4(x)\n", + "print(y4.shape) # (3,12)\n", + "\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "def f5(x_block):\n", + " return jax.lax.psum(x_block, ('i', 'j'))\n", + "\n", + "y5 = f5(x)\n", + "print(y5.shape) # (3,6)" + ] + }, + { + "cell_type": "markdown", + "id": "39218e05", + "metadata": {}, + "source": [ + "On the physical side, not mentioning a mesh axis name in an output pspec\n", + "assembles an `Array` from the output device buffers with replicated layout\n", + "along that mesh axis.\n", + "\n", + "There is no runtime check that the output blocks are actually equal along a\n", + "mesh axis to be un-tiled along, or equivalently that the corresponding physical\n", + "buffers have equal values and thus can be interpreted as a replicated layout\n", + "for a single logical array. But we can provide a static check mechanism which\n", + "raises an error on all potentially-incorrect programs.\n", + "\n", + "Because the `out_specs` can mention mesh axis names zero or one times, and\n", + "because they can be mentioned in any order, we can say that in addition to the\n", + "`jnp.concatenate` built into its output, `shard_map` also has both an _untile_\n", + "and a _block transpose_ built into its output.\n", + "\n", + "Physical data movement is not possible on outputs, no matter the output pspec.\n", + "Instead, `out_specs` just encodes how to assemble the block outputs into\n", + "`Array`s, or physically how to interpret the buffers across devices as the\n", + "physical layout of a single logical `Array`.\n", + "\n", + "# API Specification\n", + "\n", + "```python\n", + "from jax.sharding import Mesh\n", + "Specs = PyTree[PartitionSpec]\n", + "\n", + "def shard_map(\n", + " f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n", + " auto: collections.abc.Set[AxisName] = frozenset([]),\n", + " check_rep: bool = True,\n", + ") -> Callable:\n", + " ...\n", + "```\n", + "where:\n", + "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n", + "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", + "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", + "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", + "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "\n", + "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", + "passed to
`shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", + "from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and\n", + "the corresponding `PartitionSpec` `spec` as roughly\n", + "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n", + "\n", + "# Collectives tutorial\n", + "\n", + "A `shard_map` need not be a pure map: function applications can communicate\n", + "with each other via _collectives_, using axis names defined in the `mesh`\n", + "argument.\n", + "\n", + "Recall that `shard_map` maps a function over shards, or blocks, of input data,\n", + "so that this:\n", + "\n", + "```python\n", + "mesh = Mesh(jax.devices(), ('i',))\n", + "x = jnp.arange(16.)\n", + "f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))\n", + "y = f_shmapped(x)\n", + "```\n", + "\n", + "Computes the same values, evaluating applications of `f` to the same argument\n", + "values, as this reference function:\n", + "\n", + "```python\n", + "def f_shmapped_ref(x):\n", + " x_blocks = jnp.array_split(x, mesh.shape[0])\n", + " y_blocks = [f(x_blk) for x_blk in x_blocks]\n", + " return jnp.concatenate(y_blocks)\n", + "```\n", + "\n", + "We call these applications of `f` to different argument shards _function\n", + "instances_. Each function instance is executed on a different device (or subset\n", + "of devices).\n", + "\n", + "These reference semantics work when `f` has no communication collectives in\n", + "it. But what if we want the function instances to communicate, corresponding\n", + "to having cross-device communication? That is, what are the reference\n", + "semantics when `f` contains a collective? Say `f` has just one collective, and\n", + "is of the form\n", + "\n", + "```python\n", + "def f(x_blk):\n", + " z_blk = f_part1(x_blk)\n", + " u_blk = collective(z_blk, axis_name)\n", + " v_blk = f_part2(x_blk, z_blk, u_blk)\n", + " return v_blk\n", + "```\n", + "\n", + "where we're assuming there's only one mesh axis we're mapping over, and\n", + "`axis_name` is the corresponding name for it. Then the reference semantics\n", + "would look more like:\n", + "\n", + "```python\n", + "def f_shmapped_ref(x):\n", + " x_blocks = jnp.array_split(x, mesh.shape[0])\n", + " z_blocks = [f_part1(x_blk) for x_blk in x_blocks]\n", + " u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]\n", + " v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk\n", + " in zip(x_blocks, z_blocks, u_blocks)]\n", + " return jnp.concatenate(v_blocks)\n", + "```\n", + "\n", + "Notice that `collective_ref` might depend on all the `z_blocks`. That is,\n", + "while `f_part1` and `f_part2` are mapped over blocks independently, a\n", + "collective introduces some amount of cross-block dependence. Physically, that\n", + "means communication across devices. 
Exactly what communication happens, and\n", + "what values are computed, depend on the collective.\n", + "\n", + "## `psum`\n", + "\n", + "The simplest collective may be `jax.lax.psum`, which computes an\n", + "all-reduce-sum along a device mesh axis (or multiple axes).\n", + "Here's a toy example:\n", + "\n", + "\"Illustration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dac420e5", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import lax\n", + "\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "from jax.experimental.shard_map import shard_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38dd14b6", + "metadata": {}, + "outputs": [], + "source": [ + "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", + "\n", + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", + "def f1(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.psum(x_block, 'i')\n", + " print('AFTER:\\n', y_block)\n", + " return y_block" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22f0b947", + "metadata": {}, + "outputs": [], + "source": [ + "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", + "y = f1(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "e32df19e", + "metadata": {}, + "source": [ + "The prints show that each function application starts with its own chunk of\n", + "the argument value `x_block`. After the `psum`, each function application has\n", + "the same value of `y_block`, computed by summing the applications' `x_block`\n", + "values together.\n", + "\n", + "In the case where there's a single axis name in the computation, we could say\n", + "that the `collective_ref` reference implementation for `psum` is\n", + "\n", + "```python\n", + "def psum_ref(_, x_blocks):\n", + " tot = sum(x_blocks)\n", + " return [tot] * len(x_blocks)\n", + "```\n", + "\n", + "Notice also that because `f1` returns `y_block`, the result of a `psum` over\n", + "`'i'`, we can use `out_specs=P()` so the caller gets a single logical copy of\n", + "the result value, rather than a tiled result.\n", + "\n", + "When there is more than one mesh axis, we can perform a `psum` over\n", + "each one separately, or over multiple axes at once:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39fba427", + "metadata": {}, + "outputs": [], + "source": [ + "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", + "\n", + "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "def f2(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.psum(x_block, 'i')\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "y = f2(jnp.arange(16).reshape(4, 4))\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "d2bdd59b", + "metadata": {}, + "source": [ + "By applying a `psum` over mesh axis `'i'`, we get values of `y_block` which\n", + "are equal along axis '`i'`, but not axis `'j'`. 
(So we can use\n", + "`out_specs=P(None, 'j')` to get a single logical result along that axis.)\n", + "\n", + "If we apply the `psum` over both axes, the `y_block` value is equal along both\n", + "axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2919056c", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "def f3(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.psum(x_block, ('i', 'j'))\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "y = f3(jnp.arange(16).reshape(4, 4))\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "e9d1f748", + "metadata": {}, + "source": [ + "In machine learning, we often use `psum` to compute total losses or, when we\n", + "have a `grad` inside the `shard_map`ped function body, total gradients.\n", + "\n", + "In the sequel, we'll see how `psum` can be implemented in terms of other\n", + "primitives, which gives some intuition about its communication cost.\n", + "\n", + "## `all_gather`\n", + "\n", + "Another fundamental operation is gathering array shards along an axis, so that\n", + "each function application has a full copy of the data along that axis:\n", + "\n", + "\"Illustration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac45aafc", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f4(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "x = jnp.array([3, 9, 5, 2])\n", + "y = f4(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "a16048aa", + "metadata": {}, + "source": [ + "The prints show that each function application again starts with its own chunk\n", + "of the argument value `x_block`. After the `all_gather`, they have a common\n", + "value, computed by concatenating the values of `x_block`.\n", + "\n", + "(Notice that we actually can't set `out_specs=P()` here. For technical\n", + "reasons related to automatic differentiation, we consider the output of\n", + "`all_gather` not to be guaranteed invariant across devices. 
If we wanted it to\n", + "be guaranteed invariant, we could use `jax.lax.all_gather_invariant`, or in\n", + "this case we could just avoid doing the `all_gather` in the function body and\n", + "instead just use `out_specs=P('i')` to perform the concatenation.)\n", + "\n", + "When `tiled=False` (the default), results are stacked along a new axis instead\n", + "of concatenated:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d660032", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f5(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.all_gather(x_block, 'i', tiled=False)\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "y = f5(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "960bc5c3", + "metadata": {}, + "source": [ + "We could write the `collective_ref` reference semantics function for\n", + "`all_gather` as\n", + "\n", + "```python\n", + "def all_gather_ref(_, x_blocks, *, tiled=False):\n", + " combine = jnp.concatenate if tiled else jnp.stack\n", + " return [combine(x_blocks)] * len(x_blocks)\n", + "```\n", + "\n", + "In deep learning, we might use `all_gather`s on parameters in fully sharded\n", + "data parallelism (FSDP).\n", + "\n", + "# psum_scatter\n", + "\n", + "The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n", + "`psum` except each function instance gets only one shard of the result:\n", + "\n", + "\"Illustration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcc18ee8", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f6(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", + "y = f6(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "b645b6b2", + "metadata": {}, + "source": [ + "As shown by the prints, each resulting `y_block` has a smaller size than the\n", + "argument `x_block`, unlike with `psum`. Moreover, compared to `psum`, here\n", + "each `y_block` only represents a slice of the sum of the `x_block`s across\n", + "function instances. (Even though each function instance gets only one shard of\n", + "the sum, the final output `y` is the same as in the `psum` example because\n", + "here we use `out_specs=P('i')` to concatenate each function instance's\n", + "output.)\n", + "\n", + "In terms of what values are computed, a `collective_ref` reference\n", + "implementation might look like:\n", + "\n", + "```python\n", + "def psum_scatter_ref(i, x_blocks, *, tiled=False):\n", + " axis_size = len(x_blocks)\n", + " tot = sum(x_blocks)\n", + " if tiled:\n", + " tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis\n", + " return [tot[i] for i in range(tot.shape[0])]\n", + "```\n", + "\n", + "It's not captured in the semantics reference implementation, but\n", + "`psum_scatter` is useful because these results can be computed more\n", + "efficiently, with less communication, than a full `psum`. In fact, one way to\n", + "think of `psum_scatter` is as \"the first half of a `psum`, before an\n", + "`all_gather`\". 
That is, one way to implement `psum` is:\n", + "\n", + "```python\n", + "def psum(x, axis_name):\n", + " summed_chunk = jax.lax.psum_scatter(x, axis_name)\n", + " return jax.lax.all_gather(summed_chunk, axis_name)\n", + "```\n", + "\n", + "Indeed, this implementation is often used on both TPU and GPU!\n", + "\n", + "The reason `psum_scatter` can require about half the communication of a full\n", + "`psum` is illustrated in the `ppermute` section.\n", + "\n", + "Another intuition is that we can use `psum_scatter` to implement a distributed\n", + "matrix multiplication with inputs and outputs sharded over the same axis. In\n", + "machine learning, `psum_scatter` can be used in tensor-parallel matrix\n", + "multiplies or fully-sharded data parallel gradient accumulation, as shown in\n", + "the examples to follow.\n", + "\n", + "## `ppermute`\n", + "\n", + "The `jax.lax.ppermute` collective provides the most direct way for\n", + "function instances to send data to one another. Given a mesh axis and a\n", + "list of `(source_index, destination_index)` pairs representing indices along\n", + "that mesh axis, `ppermute` sends its argument value from each source function\n", + "instance to each destination:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc06e7b4", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f7(x_block):\n", + " sz = jax.lax.psum(1, 'i')\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "y = f7(jnp.arange(8))\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "0f825fd1", + "metadata": {}, + "source": [ + "In this case, with four function instances and a cyclic permutation, each\n", + "instance's value of `y_block` is the preceding instance's value of `x_block`.\n", + "\n", + "Source indices and destination indices can't be repeated. If an index does not\n", + "appear as a destination, then the value of the corresponding function\n", + "instance's result is an array of zeros.\n", + "\n", + "A `collective_ref` reference implementation could look like\n", + "\n", + "```python\n", + "def ppermute_ref(i, x_blocks, perm):\n", + " results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)\n", + " for src, dst in perm:\n", + " results[dst] = x_blocks[src]\n", + " return results\n", + "```\n", + "\n", + "Other collectives can be implemented efficiently, in terms of total\n", + "communication, using `ppermute`s where each function passes data only to its\n", + "neighbors. For example, we could implement `psum_scatter` using a sequence of\n", + "`ppermute`s and local additions this way:\n", + "\n", + "\"Illustration\n", + "\n", + "Or, with a numerical example:\n", + "\n", + "\"Illustration\n", + "\n", + "\n", + "Intuitively, on each iteration each function instance sends 'up' the value it\n", + "received on the previous iteration, and reduces (adds) the value it receives\n", + "this iteration.
In code, it might look like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b7aa515", + "metadata": {}, + "outputs": [], + "source": [ + "def psum_scatter(x, axis_name, *, tiled=False):\n", + " size = jax.lax.psum(1, axis_name)\n", + " idx = jax.lax.axis_index(axis_name) # function instance index along axis_name\n", + " if tiled:\n", + " x = x.reshape(size, -1, *x.shape[1:]) # split leading axis\n", + " shift = partial(jax.lax.ppermute, axis_name=axis_name,\n", + " perm=[(i, (i - 1) % size) for i in range(size)])\n", + " for i in range(1, size):\n", + " update = shift(x[(idx + i) % size])\n", + " x = x.at[(idx + i + 1) % size].add(update)\n", + " return x[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77cc7c7c", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f8(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = psum_scatter(x_block, 'i', tiled=True)\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", + "y = f8(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "61400a38", + "metadata": {}, + "source": [ + "On TPU, there are higher-dimensional variants of this algorithm to exploit\n", + "multiple bidirectional physical mesh axes.\n", + "\n", + "Notice that `psum_scatter` is the transpose of `all_gather`. Indeed, a way to\n", + "implement `all_gather` in terms of `ppermute` looks like the reverse of the\n", + "above process:\n", + "\n", + "\"Illustration\n", + "\n", + "In deep learning, we might use `ppermute` when implementing SPMD pipeline\n", + "parallelism, where we divide our network along its depth into stages and\n", + "evaluate the applications of stages in parallel. Or we might use `ppermute` in\n", + "parallelizing the evaluation of convolutional layers, where we shard over\n", + "spatial axes and thus devices must communicate \"halos\" to each other. Or it\n", + "may be used under-the-hood in tensor-parallel matrix multiplies.\n", + "\n", + "## `all_to_all`\n", + "\n", + "A final collective is `all_to_all`, which is essentially a block matrix\n", + "transpose operating along one positional axis and one cross-device axis:\n", + "\n", + "\"Illustration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fa39069", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "def f9(x_block):\n", + " print('BEFORE:\\n', x_block)\n", + " y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n", + " tiled=True)\n", + " print('AFTER:\\n', y_block)\n", + " return y_block\n", + "\n", + "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", + "y = f9(x)\n", + "print('FINAL RESULT:\\n', y)" + ] + }, + { + "cell_type": "markdown", + "id": "549af5f6", + "metadata": {}, + "source": [ + "The `split_axis` argument indicates which positional axis should be sharded\n", + "and partitioned across the mesh axis. The `concat_axis` argument indicates the\n", + "axis along which the communicated results should be concatenated or stacked.\n", + "\n", + "When `tiled=False` (the default), the `split_axis` axis size must equal the\n", + "size of the mesh axis named `axis_name`, and a new axis of that size is\n", + "created at position `concat_axis` for the stacked results. 
When `tiled=True`,\n", + "the `split_axis` axis size need only be evenly divisible by the size of the\n", + "mesh axis, and results are concatenated along the existing axis `concat_axis`.\n", + "\n", + "The `collective_ref` reference semantics when `split_axis=0` and\n", + "`concat_axis=0` might look like:\n", + "\n", + "```python\n", + "def all_to_all_ref(_, x_blocks, *, tiled=False):\n", + " axis_size = len(x_blocks)\n", + " if tiled:\n", + " splits = [jnp.array_split(x, axis_size) for x in x_blocks]\n", + " return [jnp.concatenate(s) for s in zip(*splits)]\n", + " else:\n", + " splits = [list(x) for x in x_blocks]\n", + " return [jnp.stack(s) for s in zip(*splits)]\n", + "```\n", + "\n", + "In deep learning, we might use `all_to_all` in mixture-of-expert routing,\n", + "where we first sort our local batch of examples according to which expert they\n", + "should go to, then apply an `all_to_all` to redistribute examples to experts.\n", + "\n", + "# Toy examples\n", + "\n", + "How might we use `shard_map` and collective communication in practice? These\n", + "examples, while simple, give some idea.\n", + "\n", + "## Matrix multiplies\n", + "\n", + "Parallelizing matrix multiplication is central in scaling up deep learning\n", + "models, both for training and for inference. When `jax.jit` automatically\n", + "parallelizes matrix multiplication, it can use one of several different\n", + "strategies, depending on matrix sizes, hardware details, and other factors. How\n", + "might we write some of those parallelized routines more explicitly using\n", + "`shard_map`? And how can we optimize them to get better compute/communication\n", + "overlap and thus improve FLOP utilization?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86b85f05", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "from jax.experimental.shard_map import shard_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcd9b561", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = Mesh(jax.devices()[:4], ('i',))\n", + "\n", + "def device_put(x, pspec):\n", + " return jax.device_put(x, NamedSharding(mesh, pspec))" + ] + }, + { + "cell_type": "markdown", + "id": "2e2b33b9", + "metadata": {}, + "source": [ + "### Example 1: `all-gather` on one side\n", + "\n", + "Consider performing a matrix multiplication where we shard the left-hand side\n", + "argument (can think: parameters) on its leading (non-contracting) dimension:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc221220", + "metadata": {}, + "outputs": [], + "source": [ + "lhs_spec = P('i', None)\n", + "lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)" + ] + }, + { + "cell_type": "markdown", + "id": "3e76bcfd", + "metadata": {}, + "source": [ + "And we shard the right-hand side argument (can think: activations) on its\n", + "contracting dimension, with a similar sharding for the output:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "272bd303", + "metadata": {}, + "outputs": [], + "source": [ + "rhs_spec = P('i', None)\n", + "rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)" + ] + }, + { + "cell_type": "markdown", + "id": "2691cf9c", + "metadata": {}, + "source": [ + "To perform this matrix multiplication, we can first all-gather the right-hand\n", + "side and then perform local matrix
multiplies against the sharded left-hand\n", + "side:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8971775e", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_allgather(lhs_block, rhs_block):\n", + " rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)\n", + " return lhs_block @ rhs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a582e7ca", + "metadata": {}, + "outputs": [], + "source": [ + "out = matmul_allgather(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "41eb2743", + "metadata": {}, + "source": [ + "That's great, but we're not getting any compute/communication overlap\n", + "here: before we can start the matmul, we need the all_gather to complete.\n", + "Here's a profile using the same code, but on larger example shapes (`(8192,\n", + "8192)` for `lhs` and `(8192, 1024)` for `rhs`):\n", + "\n", + "\"Profile\n", + "\n", + "We can get compute/communication overlap if instead of calling `all_gather` we\n", + "basically inline our above implementation of `all_gather` in terms of\n", + "`ppermute`, then interleave steps of the gather permutation with local matrix\n", + "multiplies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a6b952b", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_allgather_overlapped(lhs_block, rhs_block):\n", + " size = jax.lax.psum(1, 'i')\n", + " idx = jax.lax.axis_index('i')\n", + " shift = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i + 1) % size) for i in range(size)])\n", + "\n", + " B = lhs_block.shape[1] // size\n", + " lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)\n", + "\n", + " out_block = lhs_blocks(idx) @ rhs_block\n", + " for i in range(1, size):\n", + " rhs_block = shift(rhs_block)\n", + " out_block += lhs_blocks((idx - i) % size) @ rhs_block\n", + " return out_block" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe3cff0", + "metadata": {}, + "outputs": [], + "source": [ + "out = matmul_allgather_overlapped(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "5d683ab3", + "metadata": {}, + "source": [ + "This implementation allows overlap between communication and computation, and\n", + "also avoids gathering a large intermediate onto each device. But on TPU it uses\n", + "only half the interconnect bandwidth by permuting in only one direction along\n", + "the ring. 
To permute bidirectionally, we just split the blocks in half and send\n", + "each half in each direction:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b777f21d", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n", + " size = jax.lax.psum(1, 'i')\n", + " idx = jax.lax.axis_index('i')\n", + " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i + 1) % size) for i in range(size)])\n", + " shift_dn = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i - 1) % size) for i in range(size)])\n", + "\n", + " B = lhs_block.shape[1] // size // 2 # half-size blocks\n", + " lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)\n", + "\n", + " rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)\n", + " out_block = lhs_blocks(idx, 0) @ rhs_block_lo\n", + " out_block += lhs_blocks(idx, 1) @ rhs_block_hi\n", + " for i in range(1, size):\n", + " rhs_block_lo = shift_up(rhs_block_lo)\n", + " rhs_block_hi = shift_dn(rhs_block_hi)\n", + " out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo\n", + " out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi\n", + " return out_block" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e40d1e8c", + "metadata": {}, + "outputs": [], + "source": [ + "out = matmul_allgather_overlapped_bidi(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "884f5535", + "metadata": {}, + "source": [ + "\"Profile\n", + "\n", + "In practice, to reduce compile times we would probably roll this into a\n", + "`jax.lax.fori_loop`. We might also have additional axes of parallelism\n", + "involved.\n", + "\n", + "### Example 2: `psum_scatter` the result\n", + "\n", + "Another sharding we might start with has both `lhs` and `rhs` sharded along\n", + "their contracting dimensions, with the output sharded like `rhs` again:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee46a15f", + "metadata": {}, + "outputs": [], + "source": [ + "lhs_spec = P(None, 'i')\n", + "lhs = device_put(lhs, lhs_spec)\n", + "\n", + "rhs_spec = P('i', None)\n", + "rhs = device_put(rhs, rhs_spec)" + ] + }, + { + "cell_type": "markdown", + "id": "c1c7200f", + "metadata": {}, + "source": [ + "Here we can use a `reduce_scatter` to perform the contraction sum over shards:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05d6ad68", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_psumscatter(lhs_block, rhs_block):\n", + " out_summand = lhs_block @ rhs_block\n", + " return jax.lax.psum_scatter(out_summand, 'i', tiled=True)\n", + "\n", + "out = matmul_psumscatter(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "cad74c82", + "metadata": {}, + "source": [ + "But the scattering communication must wait for the entire local matrix multiply\n", + "to finish before it can start. 
To get communication/computation overlap, we can\n", + "inline an implementation of `psum_scatter` in terms of `ppermute`, then\n", + "interleave the communication steps with local matrix multiplies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66dfac2d", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_psumscatter_overlapped(lhs_block, rhs_block):\n", + " size = jax.lax.psum(1, 'i')\n", + " idx = jax.lax.axis_index('i')\n", + " shift = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i - 1) % size) for i in range(size)])\n", + " lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis\n", + "\n", + " out_summand = lhs_block[(idx + 1) % size] @ rhs_block\n", + " for i in range(1, size):\n", + " out_summand = shift(out_summand)\n", + " out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block\n", + " return out_summand" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "042dde0a", + "metadata": {}, + "outputs": [], + "source": [ + "out = matmul_psumscatter_overlapped(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "8c73f21e", + "metadata": {}, + "source": [ + "As in the previous example, to fully utilize interconnects on TPU, we'd run a\n", + "bidirectional version:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52db4d9e", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + " out_specs=rhs_spec)\n", + "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n", + " size = jax.lax.psum(1, 'i')\n", + " idx = jax.lax.axis_index('i')\n", + " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i + 1) % size) for i in range(size)])\n", + " shift_dn = partial(jax.lax.ppermute, axis_name='i',\n", + " perm=[(i, (i - 1) % size) for i in range(size)])\n", + "\n", + " B = lhs_block.shape[0] // size // 2 # half-size blocks\n", + " lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)\n", + "\n", + " out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block\n", + " out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block\n", + " for i in range(1, size):\n", + " out_summand_lo = shift_up(out_summand_lo)\n", + " out_summand_hi = shift_dn(out_summand_hi)\n", + " out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block\n", + " out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block\n", + " return jnp.concatenate([out_summand_lo, out_summand_hi])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c29971e8", + "metadata": {}, + "outputs": [], + "source": [ + "out = matmul_psumscatter_overlapped_bidi(lhs, rhs)\n", + "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))" + ] + }, + { + "cell_type": "markdown", + "id": "60c2d2bc", + "metadata": {}, + "source": [ + "## Neural networks\n", + "\n", + "We can use `shard_map` to parallelize computation in neural networks, either by\n", + "itself or in combination with the automatic partitioning in `jax.jit`. 
This\n", + "section has a few examples based on this toy neural network and random data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "981ad73a", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "def predict(params, inputs):\n", + " for W, b in params:\n", + " outputs = jnp.dot(inputs, W) + b\n", + " inputs = jax.nn.relu(outputs)\n", + " return outputs\n", + "\n", + "def loss(params, batch):\n", + " inputs, targets = batch\n", + " predictions = predict(params, inputs)\n", + " return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6652af0", + "metadata": {}, + "outputs": [], + "source": [ + "def init_layer(key, n_in, n_out):\n", + " k1, k2 = jax.random.split(key)\n", + " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", + " b = jax.random.normal(k2, (n_out,))\n", + " return W, b\n", + "\n", + "def init(key, layer_sizes, batch_size):\n", + " key, *keys = jax.random.split(key, len(layer_sizes))\n", + " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", + "\n", + " key, *keys = jax.random.split(key, 3)\n", + " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", + " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", + "\n", + " return params, (inputs, targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1db636cc", + "metadata": {}, + "outputs": [], + "source": [ + "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n", + "batch_size = 32\n", + "\n", + "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)" + ] + }, + { + "cell_type": "markdown", + "id": "43f9e760", + "metadata": {}, + "source": [ + "Compare these examples with the purely [automatic partitioning examples in the\n", + "\"Distributed arrays and automatic partitioning\"\n", + "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "While in those automatic partitioning examples we don't need to edit the model\n", + "functions to use different parallelization strategies, with `shard_map` we\n", + "often do.\n", + "\n", + "### 8-way batch data parallelism\n", + "\n", + "The simplest multi-device parallelism strategy is to shard the batch of inputs\n", + "and targets over multiple devices, replicate the parameters over those devices,\n", + "and apply the model in parallel to those shards of data. To evaluate the total\n", + "loss, the devices need only communicate with a scalar-sized all-reduce-sum at\n", + "the end. 
(To evaluate the gradient of the loss, the devices must perform\n", + "all-reduce-sums of parameter gradients in the backward pass.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6417125", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", + "from jax.experimental.shard_map import shard_map\n", + "from jax.experimental import mesh_utils\n", + "\n", + "devices = mesh_utils.create_device_mesh((8,))\n", + "\n", + "# replicate initial params on all devices, shard data batch over devices\n", + "mesh = Mesh(devices, ('batch',))\n", + "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", + "params = jax.device_put(params, NamedSharding(mesh, P()))\n", + "\n", + "# adapt the loss function to sum the losses across devices\n", + "def loss_dp(params, batch):\n", + " @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),\n", + " check_rep=False) # TODO remove check_rep=False\n", + " def loss_spmd(local_batch):\n", + " inputs, targets = local_batch\n", + " predictions = predict(params, inputs) # use reference 'predict`\n", + " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n", + " return jax.lax.pmean(local_loss, 'batch')\n", + " return loss_spmd(batch)" + ] + }, + { + "cell_type": "markdown", + "id": "9fe63185", + "metadata": {}, + "source": [ + "We can check that the loss and its gradients match the reference (base) model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f1a7155", + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss_dp)(params, batch))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89e0a24d", + "metadata": {}, + "outputs": [], + "source": [ + "def allclose(a, b):\n", + " return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))\n", + "\n", + "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n", + " jax.jit(jax.grad(loss_dp))(params, batch)))" + ] + }, + { + "cell_type": "markdown", + "id": "0033ac3b", + "metadata": {}, + "source": [ + "We can print the compiler IR to inspect the gradient computation and verify\n", + "that the collective all-reduce-sum operations happen where we'd expect: at the\n", + "end of the forward pass to compute the loss value, and in the backward pass to\n", + "compute the total parameter gradients.\n", + "\n", + "### 8-way fully sharded data parallelism (FSDP)\n", + "\n", + "Another strategy is to additionally shard the parameters over the devices,\n", + "all-gathering each one when the full value is needed for the `jnp.dot` or bias\n", + "addition. Since we only have one full parameter in local device memory at a\n", + "time, rather than keeping all parameters in all device memories as in the\n", + "preceding DP example, we free up significant memory that we can use for larger\n", + "models or larger batch sizes. 
And because XLA will overlap computation and\n", + "inter-device communication, the wall-clock time doesn't suffer.\n", + "\n", + "So now we need collectives in two places: the model prediction function\n", + "`predict` needs to all-gather the parameters before they're used, and as in the\n", + "DP case the loss function needs to sum the local losses to compute the total\n", + "loss.\n", + "\n", + "There's one other ingredient we need: we don't want to store the fully gathered\n", + "parameters from the forward pass for use on the backward pass. Instead, we want\n", + "to gather them again on the backward pass. We can express that by using\n", + "`jax.remat` with a [custom\n", + "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "(or a `custom_vjp`), though XLA typically does that rematerialization\n", + "automatically.\n", + "\n", + "This general [FSDP\n", + "approach](https://engineering.fb.com/2021/07/15/open-source/fsdp/) is similar\n", + "to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and\n", + "[ZeRO-3](https://arxiv.org/abs/1910.02054)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4538cd6", + "metadata": {}, + "outputs": [], + "source": [ + "# shard data batch *and params* over devices\n", + "mesh = Mesh(devices, ('batch',))\n", + "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", + "params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n", + "\n", + "# adapt the prediction function to gather weights just before their use,\n", + "# and to re-gather them on the backward pass (rather than saving them)\n", + "@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')\n", + "def predict_fsdp(params_frag, inputs):\n", + " for W_frag, b_frag in params_frag:\n", + " W = jax.lax.all_gather(W_frag, 'batch', tiled=True)\n", + " b = jax.lax.all_gather(b_frag, 'batch', tiled=True)\n", + " outputs = jnp.dot(inputs, W) + b\n", + " inputs = jax.nn.relu(outputs)\n", + " return outputs\n", + "\n", + "def loss_fsdp(params, batch):\n", + " @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", + " def loss_spmd(local_params, local_batch):\n", + " inputs, targets = local_batch\n", + " predictions = predict_fsdp(local_params, inputs)\n", + " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n", + " return jax.lax.pmean(local_loss, 'batch')\n", + " return loss_spmd(params, batch)" + ] + }, + { + "cell_type": "markdown", + "id": "30dd5f99", + "metadata": {}, + "source": [ + "Again we can check that the loss and its gradients match the reference model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1035b6fc", + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss_fsdp)(params, batch))\n", + "\n", + "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n", + " jax.jit(jax.grad(loss_fsdp))(params, batch)))" + ] + }, + { + "cell_type": "markdown", + "id": "f88ddefe", + "metadata": {}, + "source": [ + "### 8-way tensor parallelism (TP)\n", + "\n", + "Usually we don't use tensor model parallelism by itself, but seeing it in\n", + "isolation is a good warmup on parallel matrix multiplication. 
It's also a good\n", + "example of using `shard_map` in a library function, called in a larger\n", + "`jit`-based computation.\n", + "\n", + "The parallelization idea is that we'll keep the data/activations sharded over\n", + "its feature axis (rather than its batch axis), and we'll similarly shard weight\n", + "matrices over their input-feature axis (and biases over their feature axis).\n", + "Then to perform the parallel matrix multiplication, we'll perform local matrix\n", + "multiplications followed by a `psum_scatter` to sum the local results and\n", + "efficiently scatter the result's shards." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd1fb92", + "metadata": {}, + "outputs": [], + "source": [ + "devices = mesh_utils.create_device_mesh((8,))\n", + "mesh = Mesh(devices, ('feats',))\n", + "\n", + "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n", + "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n", + "\n", + "def predict_tp(params, inputs):\n", + " for W, b in params:\n", + " outputs = gemm_tp(inputs, W, b)\n", + " inputs = jax.nn.relu(outputs)\n", + " return outputs\n", + "\n", + "@partial(shard_map, mesh=mesh,\n", + " in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n", + " out_specs=P(None, 'feats'))\n", + "def gemm_tp(inputs, W, b):\n", + " block_result = jnp.dot(inputs, W)\n", + " return jax.lax.psum_scatter(block_result, 'feats',\n", + " scatter_dimension=1, tiled=True) + b\n", + "\n", + "def loss_tp(params, batch):\n", + " inputs, targets = batch\n", + " predictions = predict_tp(params, inputs)\n", + " return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!" + ] + }, + { + "cell_type": "markdown", + "id": "cf59d537", + "metadata": {}, + "source": [ + "### FSDP + TP, with `shard_map` at the top level\n", + "\n", + "We can compose these strategies together, using multiple axes of parallelism." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4605705", + "metadata": {}, + "outputs": [], + "source": [ + "devices = mesh_utils.create_device_mesh((4, 2))\n", + "mesh = Mesh(devices, ('batch', 'feats'))\n", + "\n", + "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n", + "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n", + "\n", + "# mostly same as previous predict_fsdp definition, except we call gemm_tp\n", + "@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')\n", + "def predict_fsdp_tp(params_frag, inputs):\n", + " for W_frag, b_frag in params_frag:\n", + " W = jax.lax.all_gather(W_frag, 'batch', tiled=True)\n", + " b = jax.lax.all_gather(b_frag, 'batch', tiled=True)\n", + " block_result = jnp.dot(inputs, W)\n", + " outputs = jax.lax.psum_scatter(block_result, 'feats',\n", + " scatter_dimension=1, tiled=True) + b\n", + " inputs = jax.nn.relu(outputs)\n", + " return outputs\n", + "\n", + "@partial(shard_map, mesh=mesh,\n", + " in_specs=(P(('feats', 'batch')), P('batch', 'feats')))\n", + "def loss_fsdp_tp(local_params, local_batch):\n", + " inputs, targets = local_batch\n", + " predictions = predict_fsdp_tp(local_params, inputs)\n", + " sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')\n", + " return jax.lax.pmean(jnp.mean(sq_err), 'batch')" + ] + }, + { + "cell_type": "markdown", + "id": "8220f75c", + "metadata": {}, + "source": [ + "Notice how we have to do _two_ collective reductions: one over `'feats'` and\n", + "one over `'batch'`. 
In the pure TP example, we didn't write the `'feats'`\n", + "reduction explicitly because we only used `shard_map` within `gemm_tp`; in the\n", + "caller `loss_tp`, the compiler automatically translated our use of `jnp.sum` to\n", + "perform a `psum` as needed given the sharded result returned by `predict_tp`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30b87c53", + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss_fsdp_tp)(params_, batch_))\n", + "\n", + "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n", + " jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))" + ] + }, + { + "cell_type": "markdown", + "id": "94a352ca", + "metadata": {}, + "source": [ + "### SPMD pipeline parallelism (PP)\n", + "\n", + "With pipeline parallelism we aim to parallelize the evaluation of layers at\n", + "different depths in our network. For example, one device might compute the\n", + "application of the first layer while another device computes the application of\n", + "the second; when they finish, the first device passes its results to the second\n", + "while the second passes its results to the device responsible for the third\n", + "layer, and the process repeats. In general the number of pipeline stages may be\n", + "different from the number of layers, as each stage may be responsible for\n", + "multiple layers.\n", + "\n", + "With SPMD pipelining, we exploit the fact that most layers in the network apply\n", + "the computation, just with different parameter values. In particular, we can\n", + "stack together all the parameters except for those for the first and last\n", + "layers, then use a `shard_map` to map over blocks of those layer parameters,\n", + "where each block of parameters corresponds to a pipeline stage. We then use the\n", + "`jax.lax.ppermute` collective to shift data down the parallel pipeline.\n", + "\n", + "This particular pipelining strategy is essentially [the GPipe\n", + "strategy](https://arxiv.org/abs/1811.06965). There are several variants, as\n", + "well as quite different strategies, and which is appropriate can depend on the\n", + "speed of the networking between stages and batch sizes. 
But for this tutorial\n", + "we'll focus on just one strategy.\n", + "\n", + "First, we choose some pipeline parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec88fcdb", + "metadata": {}, + "outputs": [], + "source": [ + "L = len(params) - 2 # num layers, excluding first and last\n", + "N = batch_size # batch size\n", + "F = params[0][0].shape[1] # num features\n", + "\n", + "# choose some pipeline parameters\n", + "S = 2 # number of stages\n", + "B = 8 # size of each microbatch\n", + "assert L % S == 0, \"S (number of stages) must divide L (number of inner layers)\"\n", + "\n", + "# compute some useful quantities\n", + "M, ragged = divmod(N, B) # M is number of microbatches\n", + "assert not ragged, \"B (size of each microbatch) must divide total batch size\"\n", + "K, ragged = divmod(M, S) # K is microbatches per stage\n", + "assert not ragged, \"S (number of stages) must divide number of microbatches\"\n", + "print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')\n", + "print(f'{B} examples per microbatch, {M} microbatches total')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e35bd0d6", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = Mesh(jax.devices()[:S], ('stages',))\n", + "\n", + "def predict_pp(params, inputs):\n", + " (W_first, b_first), inner_params, (W_last, b_last) = params\n", + " inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)\n", + " inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),\n", + " inner_params, inputs)\n", + " outputs = jnp.dot(inputs, W_last) + b_last\n", + " return outputs\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", + " out_specs=P())\n", + "def loss_pp(params, batch):\n", + " inputs, targets = batch\n", + " predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)\n", + " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n", + " return jax.lax.pmean(local_loss, 'stages')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "257b88e1", + "metadata": {}, + "outputs": [], + "source": [ + "def spmd_pipeline(fn, stage_params, inputs):\n", + " stage = jax.lax.axis_index('stages')\n", + " outputs = jnp.zeros_like(inputs) * jnp.nan\n", + " state = jnp.zeros((L // S, B, F)) * jnp.nan\n", + " for i in range(M+L-1):\n", + " state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))\n", + " state = jax.vmap(fn)(stage_params, state)\n", + " outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))\n", + " state, inputs, outputs = shift(i, state, inputs, outputs)\n", + " outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])\n", + " return outputs\n", + "\n", + "def shift(i, state, inputs, outputs):\n", + " sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])\n", + " state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))\n", + " if (i % K) == (-1 % K):\n", + " inputs = sh(inputs, +1)\n", + " if ((i-L+1) % K) == (-1 % K):\n", + " outputs = sh(outputs, +1)\n", + " return state, inputs, outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12a478a3", + "metadata": {}, + "outputs": [], + "source": [ + "first_params, *inner_params, last_params = params\n", + "Ws, bs = zip(*inner_params)\n", + "params_stacked = jnp.stack(Ws), jnp.stack(bs)\n", + "first_params = jax.device_put(first_params, 
NamedSharding(mesh, P()))\n", + "params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))\n", + "last_params = jax.device_put(last_params, NamedSharding(mesh, P()))\n", + "params_ = first_params, params_stacked, last_params\n", + "\n", + "batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a3086fb", + "metadata": {}, + "outputs": [], + "source": [ + "print(jax.jit(loss)(params, batch))\n", + "print(jax.jit(loss_pp)(params_, batch_))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9308c874", + "metadata": {}, + "outputs": [], + "source": [ + "_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md new file mode 100644 index 000000000..9a8ac1bf9 --- /dev/null +++ b/docs/notebooks/shard_map.md @@ -0,0 +1,1373 @@ +--- +jupytext: + cell_metadata_filter: -all + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.0 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Intro + +`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. + +`shard_map` is complementary to, and comopsable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. + +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) + +By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. + +## So, let's see a `shard_map`! 
+ +Without further ado, here's a toy example: + +```{code-cell} +from functools import partial + +import jax +import jax.numpy as jnp + +from jax.sharding import Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +``` + +```{code-cell} +devices = mesh_utils.create_device_mesh((4, 2)) +mesh = Mesh(devices, axis_names=('x', 'y')) + +a = jnp.arange( 8 * 16.).reshape(8, 16) +b = jnp.arange(16 * 4.).reshape(16, 4) + +@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), + out_specs=P('x', None)) +def matmul_basic(a_block, b_block): + # a_block: f32[2, 8] + # b_block: f32[8, 4] + c_partialsum = jnp.dot(a_block, b_block) + c_block = jax.lax.psum(c_partialsum, 'y') + # c_block: f32[2, 4] + return c_block + +c = matmul_basic(a, b) # c: f32[8, 4] +``` + +This function computes a matrix multiply in parallel by performing local block matrix multipiles followed by a collective sum operation. We can check the result is correct: + +```{code-cell} +from jax.tree_util import tree_map, tree_all + +def allclose(a, b): + return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b)) + +allclose(c, jnp.dot(a, b)) +``` + +The result is sharded along its rows: + +```{code-cell} +jax.debug.visualize_array_sharding(c) +``` + +At a high level, `shard_map` is kind of like `vmap` or `pmap`, in that we're +mapping a function over pieces of array data, but notice that +* `shard_map` slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas `vmap` would reduce the rank by mapping away an axis; +* the `mesh` argument lets us control precise device placement of computation and results; +* we're mapping over multiple data axes at once, and setting up multiple axis names for collectives (both `'x'` and `'y'` here); +* since we're not using `jax.jit` yet, everything is eagerly evaluated, and we can even `print` intermediate values for debugging. + +The above code is performing the same computation as this `jax.jit` automatic parallelization code: + +```{code-cell} +from jax.sharding import NamedSharding + +a = jax.device_put(a, NamedSharding(mesh, P('x', 'y'))) +b = jax.device_put(b, NamedSharding(mesh, P('y', None))) + +@jax.jit +def matmul_reference(a, b): + c = jnp.dot(a, b) + return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None))) + +c_ref = matmul_reference(a, b) +allclose(c_ref, jnp.dot(a, b)) +``` + +We can think of `shard_map` as performing a `device_put` or +`with_sharding_constraint` on its inputs according to its `mesh` and `in_specs` +arguments, so the blocks over which `matmul_basic` operates are the same as in +`matmul_reference`: + +```{code-cell} +print('a blocks:'); jax.debug.visualize_array_sharding(a) +print('b blocks:'); jax.debug.visualize_array_sharding(b) +print('c blocks:'); jax.debug.visualize_array_sharding(c) +``` + +## Slow down, start with the basics! + +### Rank-reducing vs rank-preserving maps + +We can think of `vmap` and `pmap` as unstacking each array input along an axis +(e.g. 
unpacking a 2D matrix into its 1D rows), applying its body function to +each piece, and stacking the results back together, at least when collectives +aren't involved: + +```{code-cell} +def check_vmap(f, xs): + ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs) + expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics + print(allclose(ans, expected)) + +check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3)) +``` + +For example, if `xs` had shape `f32[8,5]` then each `x` would have shape +`f32[5]`, and if each `f(x)` had shape `f32[3,7]` then the final stacked result +`vmap(f)(xs)` would have shape `f32[8,3,7]`. That is, each application of the +body function `f` takes as argument inputs with one fewer axis than the +corresponding argument to `vmap(f)`. We can say these are _rank-reducing maps_ +with unstacking/stacking of inputs/outputs. + +The number of logical applications of `f`, or _instances_ of `f`, is determined +by the size of the input axis being mapped over: for example, if we map over an +input axis of size 8, semantically we get 8 logical applications of the +function. + +In contrast, `shard_map` does not have this rank-reducing behavior. Instead, we +can think of it as slicing (or "unconcatenating") along input axes into blocks, +applying the body function, and concatenating the results back together (again +when collectives aren't involved): + +```{code-cell} +import numpy as np +devices = np.array(jax.devices()[:4]) +mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 + +def check_shmap(f, y): + ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y) + expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])]) + print(allclose(ans, expected)) + +check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4)) +``` + +Recall that jnp.split slices its input into equally-sized blocks with the same +rank, so that if in the above example `y` had shape `f32[8,5]` then each +`y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape +`f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have +shape `f32[12,7]`. So `shard_map` maps over _shards_, or blocks, of its inputs. +We can say it's a _rank-preserving map_ with unconcatenating/concatenating of +its inputs/outputs. + +The number of logical applications of `f` is determined by the mesh size, not +by any input axis size: for example, if we have a mesh of total size 4 (i.e. +over 4 devices) then semantically we get 4 logical applications of the +function, corresponding to the 4 devices physically computing them. + +### Controlling how each input is split (unconcatenated) and tiled with `in_specs` + +Each of the `in_specs` identifies some of the corresponding input array's axes +with mesh axes by name using `PartitionSpec`s, representing how to split (or +unconcatenate) that input into the blocks to which the body function is +applied. That identification determines the shard sizes; when an input axis is +identified with a mesh axis, the input is split (unconcatenated) along that +logical axis into a number of pieces equal to the corresponding mesh axis size. +(It's an error if the corresponding mesh axis size does not evenly divide the +input array axis size.) If an input's pspec does not mention a mesh axis name, +then there's no splitting over that mesh axis. 
For example: + +```{code-cell} +devices = mesh_utils.create_device_mesh((4, 2)) +mesh = Mesh(devices, ('i', 'j')) + +@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) +def f1(x_block): + print(x_block.shape) # prints (3, 12) + return x_block + +x1 = jnp.arange(12 * 12).reshape(12, 12) +y = f1(x1) +``` + +Here, because the input pspec did not mention the mesh axis name `'j'`, no +input array axis is split over that mesh axis; similarly, because the second +axis of the input array is not identified with (and hence split over) any mesh +axis, application of `f1` gets a full view of the input along that axis. + +When a mesh axis is not mentioned in an input pspec, we can always rewrite to a +less efficient program where all mesh axes are mentioned but the caller +performs a `jnp.tile`, for example: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) +def f2(x_block): + print(x_block.shape) + return x_block + +x = jnp.arange(12 * 12).reshape(12, 12) +x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24) +y = f2(x_) # prints (3,12), and f1(x) == f2(x_) +``` + +In other words, because each input pspec can mention each mesh axis name zero +or one times, rather than having to mention each name exactly once, we can say +that in addition to the `jnp.split` built into its input, `shard_map` also has +a `jnp.tile` built into its input, at least logically (though the tiling may +not need to be carried out physically, depending on the arguments' physical +sharding layout). The tiling to use is not unique; we could also have tiled +along the first axis, and used the pspec `P(('j', 'i'), None)`. + +Physical data movement is possible on inputs, as each device needs to have a +copy of the appropriate data. + +### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` + +Analogously to the input side, each of the `out_specs` identifies some of the +corresponding output array's axes with mesh axes by name, representing how the +output blocks (one for each application of the body function, or equivalently +one for each physical device) should be assembled back together to form the +final output value. For example, in both the `f1` and `f2` examples above the +`out_specs` indicate we should form the final output by concatenating together +the block results along both axes, resulting in both cases an array `y` of +shape `(12, 24)`. (It's an error if an output shape of the body function, i.e. +an output block shape, has a rank too small for the concatenation described by +the corresponding output pspec.) + +When a mesh axis name is not mentioned in an output pspec, it represents an +un-tiling: when the user writes an output pspec which does not mention one of +the mesh axis names, they promise that the output blocks are equal along that +mesh axis, and so only one block along that axis is used in the output (rather +than concatenating all the blocks together along that mesh axis). 
For example, +using the same mesh as above: + +```{code-cell} +x = jnp.array([[3.]]) + +z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() +print(z) # prints the same as jnp.tile(x, (4, 2)) + +z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() +print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) + +z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() +print(z) # prints the same as jnp.tile(x, (1, 1)), or just x +``` + +The body function closing over an array value is equivalent to passing it as an +augment with a corresponding input pspec of P(None, None). As another example, +following more closely to the other examples above: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) +def f3(x_block): + return jax.lax.psum(x_block, 'j') + +x = jnp.arange(12 * 12).reshape(12, 12) +y3 = f3(x) +print(y3.shape) +``` + +The result has a second axis size of 6, half the size of the input's second +axis. In this case, the un-tile expressed by not mentioning the mesh axis name +`'j'` in the output pspec was safe because of the collective `psum`, which +ensures each output block is equal along the corresponding mesh axis. Here are +two more examples where we vary which mesh axes are mentioned in the output +pspec: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +def f4(x_block): + return jax.lax.psum(x_block, 'i') + +x = jnp.arange(12 * 12).reshape(12, 12) +y4 = f4(x) +print(y4.shape) # (3,12) + + +@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) +def f5(x_block): + return jax.lax.psum(x_block, ('i', 'j')) + +y5 = f5(x) +print(y5.shape) # (3,6) +``` + +On the physical side, not mentioning a mesh axis name in an output pspec +assembles an `Array` from the output device buffers with replicated layout +along that mesh axis. + +There is no runtime check that the output blocks are actually equal along a +mesh axis to be un-tiled along, or equivalently that the corresponding physical +buffers have equal values and thus can be interpreted as a replicated layout +for a single logical array. But we can provide a static check mechanism which +raises an error on all potentially-incorrect programs. + +Because the `out_specs` can mention mesh axis names zero or one times, and +because they can be mentioned in any order, we can say that in addition to the +`jnp.concatenate` built into its output, `shard_map` also has both an _untile_ +and a _block transpose_ built into its output. + +Physical data movement is not possible on outputs, no matter the output pspec. +Instead, `out_specs` just encodes how to assemble the block outputs into +`Array`s, or physically how to interpret the buffers across devices as the +physical layout of a single logical `Array`. + +# API Specification + +```python +from jax.sharding import Mesh +Specs = PyTree[PartitionSpec] + +def shard_map( + f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, + auto: collections.abc.Set[AxisName] = frozenset([]), + check_rep: bool = True, +) -> Callable: + ... 
+``` +where: +* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`; +* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; +* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; +* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the obdy, as in the caller, rather than manually; +* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). + +The shapes of the arguments passed to `f` have the same ranks as the arguments +passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed +from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and +the corresponding `PartitionSpec` `spec` as roughly +`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`. + +# Collectives tutorial + +A `shard_map` need not be a pure map: function applications can communicate +with each other via _collectives_, using axis names defined in the `mesh` +argument. + +Recall that `shard_map` maps a function over shards, or blocks, of input data, +so that this: + +```python +mesh = Mesh(jax.devices(), ('i',)) +x = jnp.arange(16.) +f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) +y = f_shmapped(x) +``` + +Computes the same values, evaluating applications of `f` to the same argument +values, as this reference function: + +```python +def f_shmapped_ref(x): + x_blocks = jnp.array_split(x, mesh.shape[0]) + y_blocks = [f(x_blk) for x_blk in x_blocks] + return jnp.concatenate(y_blocks) +``` + +We call these applications of `f` to different argument shards _function +instances_. Each function instance is executed on a different device (or subset +of devices). + +These reference semantics work when `f` has no communication collectives in +it. But what if we want the function instances to communicate, corresponding +to having cross-device communication? That is, what are the reference +semantics when `f` contains a collective? Say `f` has just one collective, and +is of the form + +```python +def f(x_blk): + z_blk = f_part1(x_blk) + u_blk = collective(z_blk, axis_name) + v_blk = f_part2(x_blk, z_blk, u_blk) + return v_blk +``` + +where we're assuming there's only one mesh axis we're mapping over, and +`axis_name` is the corresponding name for it. Then the reference semantics +would look more like: + +```python +def f_shmapped_ref(x): + x_blocks = jnp.array_split(x, mesh.shape[0]) + z_blocks = [f_part1(x_blk) for x_blk in x_blocks] + u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))] + v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk + in zip(x_blocks, z_blocks, u_blocks)] + return jnp.concatenate(v_blocks) +``` + +Notice that `collective_ref` might depend on all the `z_blocks`. That is, +while `f_part1` and `f_part2` are mapped over blocks independently, a +collective introduces some amount of cross-block dependence. Physically, that +means communication across devices. 
Exactly what communication happens, and +what values are computed, depend on the collective. + +## `psum` + +The simplest collective may be `jax.lax.psum`, which computes an +all-reduce-sum along a device mesh axis (or multiple axes). +Here's a toy example: + +Illustration of a psum computation. + +```{code-cell} +import jax +import jax.numpy as jnp +from jax import lax + +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental.shard_map import shard_map +``` + +```{code-cell} +mesh1d = Mesh(jax.devices()[:4], ('i',)) + +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) +def f1(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.psum(x_block, 'i') + print('AFTER:\n', y_block) + return y_block +``` + +```{code-cell} +x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) +y = f1(x) +print('FINAL RESULT:\n', y) +``` + +The prints show that each function application starts with its own chunk of +the argument value `x_block`. After the `psum`, each function application has +the same value of `y_block`, computed by summing the applications' `x_block` +values together. + +In the case where there's a single axis name in the computation, we could say +that the `collective_ref` reference implementation for `psum` is + +```python +def psum_ref(_, x_blocks): + tot = sum(x_blocks) + return [tot] * len(x_blocks) +``` + +Notice also that because `f1` returns `y_block`, the result of a `psum` over +`'i'`, we can use `out_specs=P()` so the caller gets a single logical copy of +the result value, rather than a tiled result. + +When there is more than one mesh axis, we can perform a `psum` over +each one separately, or over multiple axes at once: + +```{code-cell} +mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) + +@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +def f2(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.psum(x_block, 'i') + print('AFTER:\n', y_block) + return y_block + +y = f2(jnp.arange(16).reshape(4, 4)) +print('FINAL RESULT:\n', y) +``` + +By applying a `psum` over mesh axis `'i'`, we get values of `y_block` which +are equal along axis '`i'`, but not axis `'j'`. (So we can use +`out_specs=P(None, 'j')` to get a single logical result along that axis.) + +If we apply the `psum` over both axes, the `y_block` value is equal along both +axes: + +```{code-cell} +@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) +def f3(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.psum(x_block, ('i', 'j')) + print('AFTER:\n', y_block) + return y_block + +y = f3(jnp.arange(16).reshape(4, 4)) +print('FINAL RESULT:\n', y) +``` + +In machine learning, we often use `psum` to compute total losses or, when we +have a `grad` inside the `shard_map`ped function body, total gradients. + +In the sequel, we'll see how `psum` can be implemented in terms of other +primitives, which gives some intuition about its communication cost. + +## `all_gather` + +Another fundamental operation is gathering array shards along an axis, so that +each function application has a full copy of the data along that axis: + +Illustration of an all_gather computation. 
+ +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f4(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.all_gather(x_block, 'i', tiled=True) + print('AFTER:\n', y_block) + return y_block + +x = jnp.array([3, 9, 5, 2]) +y = f4(x) +print('FINAL RESULT:\n', y) +``` + +The prints show that each function application again starts with its own chunk +of the argument value `x_block`. After the `all_gather`, they have a common +value, computed by concatenating the values of `x_block`. + +(Notice that we actually can't set `out_specs=P()` here. For technical +reasons related to automatic differentiation, we consider the output of +`all_gather` not to be guaranteed invariant across devices. If we wanted it to +be guaranteed invariant, we could use `jax.lax.all_gather_invariant`, or in +this case we could just avoid doing the `all_gather` in the function body and +instead just use `out_specs=P('i')` to perform the concatenation.) + +When `tiled=False` (the default), results are stacked along a new axis instead +of concatenated: + +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f5(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.all_gather(x_block, 'i', tiled=False) + print('AFTER:\n', y_block) + return y_block + +y = f5(x) +print('FINAL RESULT:\n', y) +``` + +We could write the `collective_ref` reference semantics function for +`all_gather` as + +```python +def all_gather_ref(_, x_blocks, *, tiled=False): + combine = jnp.concatenate if tiled else jnp.stack + return [combine(x_blocks)] * len(x_blocks) +``` + +In deep learning, we might use `all_gather`s on parameters in fully sharded +data parallelism (FSDP). + +# psum_scatter + +The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like +`psum` except each function instance gets only one shard of the result: + +Illustration of a psum_scatter computation. + +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f6(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) + print('AFTER:\n', y_block) + return y_block + +x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) +y = f6(x) +print('FINAL RESULT:\n', y) +``` + +As shown by the prints, each resulting `y_block` has a smaller size than the +argument `x_block`, unlike with `psum`. Moreover, compared to `psum`, here +each `y_block` only represents a slice of the sum of the `x_block`s across +function instances. (Even though each function instance gets only one shard of +the sum, the final output `y` is the same as in the `psum` example because +here we use `out_specs=P('i')` to concatenate each function instance's +output.) + +In terms of what values are computed, a `collective_ref` reference +implementation might look like: + +```python +def psum_scatter_ref(i, x_blocks, *, tiled=False): + axis_size = len(x_blocks) + tot = sum(x_blocks) + if tiled: + tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis + return [tot[i] for i in range(tot.shape[0])] +``` + +It's not captured in the semantics reference implementation, but +`psum_scatter` is useful because these results can be computed more +efficiently, with less communication, than a full `psum`. In fact, one way to +think of `psum_scatter` is as "the first half of a `psum`, before an +`all_gather`". 
That is, one way to implement `psum` is: + +```python +def psum(x, axis_name): + summed_chunk = jax.lax.psum_scatter(x, axis_name) + return jax.lax.all_gather(summed_chunk, axis_name) +``` + +Indeed, this implementation is often used on both TPU and GPU! + +The reason `psum_scatter` can require about half the communication as a full +`psum` is illustrated the `ppermute` section. + +Another intuition is that we can use `psum_scatter` to implement a distributed +matrix multiplication with inputs and outputs sharded over the same axis. In +machine learning, `psum_scatter` can be used in tensor-parallel matrix +multiplies or fully-sharded data parallel gradient accumulation, as shown in +the examples to follow. + +## `ppermute` + +The `jax.lax.ppermute` collective provides the most direct way for +function instances to send data to one another. Given a mesh axis and a +list of `(source_index, destination_index)` pairs representing indices along +that mesh axis, `ppermute` sends its argument value from each source function +instance to each destination: + +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f7(x_block): + sz = jax.lax.psum(1, 'i') + print('BEFORE:\n', x_block) + y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)]) + print('AFTER:\n', y_block) + return y_block + +y = f7(jnp.arange(8)) +print('FINAL RESULT:\n', y) +``` + +In this case, with just two function instances, each instance's value of +`y_block` is the other's value of `x_block`. + +Source indices and destination indices can't be repeated. If an index does not +appear as a destination, then the value of the corresponding function +instance's result is an array of zeros. + +A `collective_ref` reference implementation could look like + +```python +def ppermute_ref(i, x_blocks, perm): + results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks) + for src, dst in perm: + results[dst] = x_blocks[src] + return results +``` + +Other collectives can be implemented efficiently, in terms of total +communication, using `ppermute`s where each function passes data only to its +neighbors. For example, we could implement `psum_scatter` using a sequence of +`ppermute`s and local additions this way: + +Illustration of a psum_scatter implementation. + +Or, with a numerical example: + +Illustration of a psum_scatter implementation. + + +Intuitively, on each iteration each function instance sends 'up' the value it +received on the previous iteration, and reduces (adds) the value it receives +this iteration. 
In code, it might look like this: + +```{code-cell} +def psum_scatter(x, axis_name, *, tiled=False): + size = jax.lax.psum(1, axis_name) + idx = jax.lax.axis_index(axis_name) # function instance index along axis_name + if tiled: + x = x.reshape(size, -1, *x.shape[1:]) # split leading axis + shift = partial(jax.lax.ppermute, axis_name=axis_name, + perm=[(i, (i - 1) % size) for i in range(size)]) + for i in range(1, size): + update = shift(x[(idx + i) % size]) + x = x.at[(idx + i + 1) % size].add(update) + return x[idx] +``` + +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f8(x_block): + print('BEFORE:\n', x_block) + y_block = psum_scatter(x_block, 'i', tiled=True) + print('AFTER:\n', y_block) + return y_block + +x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) +y = f8(x) +print('FINAL RESULT:\n', y) +``` + +On TPU, there are higher-dimensional variants of this algorithm to exploit +multiple bidirectional physical mesh axes. + +Notice that `psum_scatter` is the transpose of `all_gather`. Indeed, a way to +implement `all_gather` in terms of `ppermute` looks like the reverse of the +above process: + +Illustration of an all_gather implementation. + +In deep learning, we might use `ppermute` when implementing SPMD pipeline +parallelism, where we divide our network along its depth into stages and +evaluate the applications of stages in parallel. Or we might use `ppermute` in +parallelizing the evaluation of convolutional layers, where we shard over +spatial axes and thus devices must communicate "halos" to each other. Or it +may be used under-the-hood in tensor-parallel matrix multiplies. + +## `all_to_all` + +A final collective is `all_to_all`, which is essentially a block matrix +transpose operating along one positional axis and one cross-device axis: + +Illustration of an all_to_all computation. + +```{code-cell} +@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +def f9(x_block): + print('BEFORE:\n', x_block) + y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0, + tiled=True) + print('AFTER:\n', y_block) + return y_block + +x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) +y = f9(x) +print('FINAL RESULT:\n', y) +``` + +The `split_axis` argument indicates which positional axis should be sharded +and partitioned across the mesh axis. The `concat_axis` argument indicates the +axis along which the communicated results should be concatenated or stacked. + +When `tiled=False` (the default), the `split_axis` axis size must equal the +size of the mesh axis named `axis_name`, and a new axis of that size is +created at position `concat_axis` for the stacked results. When `tiled=True`, +the `split_axis` axis size need only be evenly divisible by the size of the +mesh axis, and results are concatenated along the existing axis `concat_axis`. + +The `collective_ref` reference semantics when `split_axis=0` and +`concat_axis=0` might look like: + +```python +def all_to_all_ref(_, x_blocks, *, tiled=False): + axis_size = len(x_blocks) + if tiled: + splits = [jnp.array_split(x, axis_size) for x in x_blocks] + return [jnp.concatenate(s) for s in zip(*splits)] + else: + splits = [list(x) for x in x_blocks] + return [jnp.stack(s) for s in zip(*splits)] +``` + +In deep learning, we might use `all_to_all` in mixture-of-expert routing, +where we first sort our local batch of examples according to which expert they +should go to, then apply an `all_to_all` to redistribute examples to experts. 
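+
+To make that concrete, here is a small hypothetical routing sketch (not taken
+from the toy examples below): we assume four experts, one per device of
+`mesh1d`, and that each device's local batch happens to hold exactly two
+examples per expert, so every device sends an equal-sized chunk to every other
+device (real mixture-of-experts systems enforce this with a fixed expert
+capacity):
+
+```python
+n_experts = mesh1d.shape['i']  # assume one expert per device
+
+@partial(shard_map, mesh=mesh1d, in_specs=(P('i'), P('i')), out_specs=P('i'))
+def route_to_experts(x_block, expert_block):
+  # sort this device's examples by their assigned expert...
+  order = jnp.argsort(expert_block)
+  x_sorted = x_block[order]
+  # ...then redistribute, so device k receives every device's chunk for expert k
+  return jax.lax.all_to_all(x_sorted, 'i', split_axis=0, concat_axis=0,
+                            tiled=True)
+
+xs = jnp.arange(32.)                                         # 8 examples per device
+experts = jnp.tile(jnp.repeat(jnp.arange(n_experts), 2), 4)  # 2 per expert per device
+routed = route_to_experts(xs, experts)
+```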
+ +# Toy examples + +How might we use `shard_map` and collective communication in practice? These +examples, while simple, give some idea. + +## Matrix multiplies + +Parallelizing matrix multiplication is central in scaling up deep learning +models, both for training and for inference. When `jax.jit` automatically +parallelizes matrix multiplication, it can use one of several different +strategies, depending on matrix sizes, hardware details, and other factors. How +might we write some of those parallelized routines more explicitly using +`shard_map`? And how can we optimize them to get better compute/communication +overlap and thus improve FLOP utilization? + +```{code-cell} +import jax +import jax.numpy as jnp + +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental.shard_map import shard_map +``` + +```{code-cell} +mesh = Mesh(jax.devices()[:4], ('i',)) + +def device_put(x, pspec): + return jax.device_put(x, NamedSharding(mesh, pspec)) +``` + +### Example 1: `all-gather` on one side + +Consider performing a matrix multiplication where we shard the left-hand side +argument (can think: parameters) on its leading (non-contracting) dimension: + +```{code-cell} +lhs_spec = P('i', None) +lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec) +``` + +And wee shard the right-hand side argument (can think: activations) on its +contracting dimension, with a similar sharding for the output: + +```{code-cell} +rhs_spec = P('i', None) +rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec) +``` + +To perform this matrix multiplication, we can first all-gather the right-hand +side and then perform local matrix multiplies against the sharded left-hand +side: + +```{code-cell} +@jax.jit +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_allgather(lhs_block, rhs_block): + rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True) + return lhs_block @ rhs +``` + +```{code-cell} +out = matmul_allgather(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +That's great, but we're not getting any compute/communication overlap +here: before we can start the matmul, we need the all_gather to complete. +Here's a profile using the same code, but on larger example shapes (`(8192, +8192)` for `lhs` and `(8192, 1024)` for `rhs`): + +Profile of an all-gather matmul without overlap. 
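+
+(If you'd like to capture a profile like this yourself, one option is JAX's
+standard profiler context manager, viewing the resulting trace in
+TensorBoard/Perfetto; the output directory below is just a placeholder:)
+
+```python
+import jax.profiler
+
+# rough sketch; '/tmp/jax-trace' is a placeholder log directory
+with jax.profiler.trace('/tmp/jax-trace'):
+  matmul_allgather(lhs, rhs).block_until_ready()
+```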
+ +We can get compute/communication overlap if instead of calling `all_gather` we +basically inline our above implementation of `all_gather` in terms of +`ppermute`, then interleave steps of the gather permutation with local matrix +multiplies: + +```{code-cell} +@jax.jit +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_allgather_overlapped(lhs_block, rhs_block): + size = jax.lax.psum(1, 'i') + idx = jax.lax.axis_index('i') + shift = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i + 1) % size) for i in range(size)]) + + B = lhs_block.shape[1] // size + lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1) + + out_block = lhs_blocks(idx) @ rhs_block + for i in range(1, size): + rhs_block = shift(rhs_block) + out_block += lhs_blocks((idx - i) % size) @ rhs_block + return out_block +``` + +```{code-cell} +out = matmul_allgather_overlapped(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +This implementation allows overlap between communication and computation, and +also avoids gathering a large intermediate onto each device. But on TPU it uses +only half the interconnect bandwidth by permuting in only one direction along +the ring. To permute bidirectionally, we just split the blocks in half and send +each half in each direction: + +```{code-cell} +@jax.jit +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_allgather_overlapped_bidi(lhs_block, rhs_block): + size = jax.lax.psum(1, 'i') + idx = jax.lax.axis_index('i') + shift_up = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i + 1) % size) for i in range(size)]) + shift_dn = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i - 1) % size) for i in range(size)]) + + B = lhs_block.shape[1] // size // 2 # half-size blocks + lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1) + + rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0) + out_block = lhs_blocks(idx, 0) @ rhs_block_lo + out_block += lhs_blocks(idx, 1) @ rhs_block_hi + for i in range(1, size): + rhs_block_lo = shift_up(rhs_block_lo) + rhs_block_hi = shift_dn(rhs_block_hi) + out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo + out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi + return out_block +``` + +```{code-cell} +out = matmul_allgather_overlapped_bidi(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +Profile of an all-gather matmul with overlap. + +In practice, to reduce compile times we would probably roll this into a +`jax.lax.fori_loop`. We might also have additional axes of parallelism +involved. 
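+
+For concreteness, here's a rough sketch of that rollup for the unidirectional
+version (the bidirectional one follows the same pattern); the name
+`matmul_allgather_overlapped_fori` is ours, and the loop carry holds the
+output accumulator together with the block being permuted:
+
+```python
+@jax.jit
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+         out_specs=rhs_spec)
+def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
+  size = jax.lax.psum(1, 'i')
+  idx = jax.lax.axis_index('i')
+  shift = partial(jax.lax.ppermute, axis_name='i',
+                  perm=[(i, (i + 1) % size) for i in range(size)])
+
+  B = lhs_block.shape[1] // size
+  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
+
+  def body(i, carry):
+    out_block, rhs_block = carry
+    rhs_block = shift(rhs_block)                           # one communication step
+    out_block += lhs_blocks((idx - i) % size) @ rhs_block  # one compute step
+    return out_block, rhs_block
+
+  out_block, _ = jax.lax.fori_loop(
+      1, size, body, (lhs_blocks(idx) @ rhs_block, rhs_block))
+  return out_block
+
+out = matmul_allgather_overlapped_fori(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```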
+ +### Example 2: `psum_scatter` the result + +Another sharding we might start with has both `lhs` and `rhs` sharded along +their contracting dimensions, with the output sharded like `rhs` again: + +```{code-cell} +lhs_spec = P(None, 'i') +lhs = device_put(lhs, lhs_spec) + +rhs_spec = P('i', None) +rhs = device_put(rhs, rhs_spec) +``` + +Here we can use a `reduce_scatter` to perform the contraction sum over shards: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_psumscatter(lhs_block, rhs_block): + out_summand = lhs_block @ rhs_block + return jax.lax.psum_scatter(out_summand, 'i', tiled=True) + +out = matmul_psumscatter(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +But the scattering communication must wait for the entire local matrix multiply +to finish before it can start. To get communication/computation overlap, we can +inline an implementation of `psum_scatter` in terms of `ppermute`, then +interleave the communication steps with local matrix multiplies: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_psumscatter_overlapped(lhs_block, rhs_block): + size = jax.lax.psum(1, 'i') + idx = jax.lax.axis_index('i') + shift = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i - 1) % size) for i in range(size)]) + lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis + + out_summand = lhs_block[(idx + 1) % size] @ rhs_block + for i in range(1, size): + out_summand = shift(out_summand) + out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block + return out_summand +``` + +```{code-cell} +out = matmul_psumscatter_overlapped(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +As in the previous example, to fully utilize interconnects on TPU, we'd run a +bidirectional version: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), + out_specs=rhs_spec) +def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block): + size = jax.lax.psum(1, 'i') + idx = jax.lax.axis_index('i') + shift_up = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i + 1) % size) for i in range(size)]) + shift_dn = partial(jax.lax.ppermute, axis_name='i', + perm=[(i, (i - 1) % size) for i in range(size)]) + + B = lhs_block.shape[0] // size // 2 # half-size blocks + lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0) + + out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block + out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block + for i in range(1, size): + out_summand_lo = shift_up(out_summand_lo) + out_summand_hi = shift_dn(out_summand_hi) + out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block + out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block + return jnp.concatenate([out_summand_lo, out_summand_hi]) +``` + +```{code-cell} +out = matmul_psumscatter_overlapped_bidi(lhs, rhs) +print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) +``` + +## Neural networks + +We can use `shard_map` to parallelize computation in neural networks, either by +itself or in combination with the automatic partitioning in `jax.jit`. 
This +section has a few examples based on this toy neural network and random data: + +```{code-cell} +import jax +import jax.numpy as jnp + +def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jax.nn.relu(outputs) + return outputs + +def loss(params, batch): + inputs, targets = batch + predictions = predict(params, inputs) + return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1)) +``` + +```{code-cell} +def init_layer(key, n_in, n_out): + k1, k2 = jax.random.split(key) + W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) + b = jax.random.normal(k2, (n_out,)) + return W, b + +def init(key, layer_sizes, batch_size): + key, *keys = jax.random.split(key, len(layer_sizes)) + params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) + + key, *keys = jax.random.split(key, 3) + inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) + targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) + + return params, (inputs, targets) +``` + +```{code-cell} +layer_sizes = [784, 128, 128, 128, 128, 128, 8] +batch_size = 32 + +params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size) +``` + +Compare these examples with the purely [automatic partitioning examples in the +"Distributed arrays and automatic partitioning" +doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +While in those automatic partitioning examples we don't need to edit the model +functions to use different parallelization strategies, with `shard_map` we +often do. + +### 8-way batch data parallelism + +The simplest multi-device parallelism strategy is to shard the batch of inputs +and targets over multiple devices, replicate the parameters over those devices, +and apply the model in parallel to those shards of data. To evaluate the total +loss, the devices need only communicate with a scalar-sized all-reduce-sum at +the end. (To evaluate the gradient of the loss, the devices must perform +all-reduce-sums of parameter gradients in the backward pass.) 
+ +```{code-cell} +from functools import partial + +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental.shard_map import shard_map +from jax.experimental import mesh_utils + +devices = mesh_utils.create_device_mesh((8,)) + +# replicate initial params on all devices, shard data batch over devices +mesh = Mesh(devices, ('batch',)) +batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) +params = jax.device_put(params, NamedSharding(mesh, P())) + +# adapt the loss function to sum the losses across devices +def loss_dp(params, batch): + @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(), + check_rep=False) # TODO remove check_rep=False + def loss_spmd(local_batch): + inputs, targets = local_batch + predictions = predict(params, inputs) # use reference 'predict` + local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1)) + return jax.lax.pmean(local_loss, 'batch') + return loss_spmd(batch) +``` + +We can check that the loss and its gradients match the reference (base) model: + +```{code-cell} +print(jax.jit(loss)(params, batch)) +print(jax.jit(loss_dp)(params, batch)) +``` + +```{code-cell} +def allclose(a, b): + return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b)) + +print(allclose(jax.jit(jax.grad(loss))(params, batch), + jax.jit(jax.grad(loss_dp))(params, batch))) +``` + +We can print the compiler IR to inspect the gradient computation and verify +that the collective all-reduce-sum operations happen where we'd expect: at the +end of the forward pass to compute the loss value, and in the backward pass to +compute the total parameter gradients. + +### 8-way fully sharded data parallelism (FSDP) + +Another strategy is to additionally shard the parameters over the devices, +all-gathering each one when the full value is needed for the `jnp.dot` or bias +addition. Since we only have one full parameter in local device memory at a +time, rather than keeping all parameters in all device memories as in the +preceding DP example, we free up significant memory that we can use for larger +models or larger batch sizes. And because XLA will overlap computation and +inter-device communication, the wall-clock time doesn't suffer. + +So now we need collectives in two places: the model prediction function +`predict` needs to all-gather the parameters before they're used, and as in the +DP case the loss function needs to sum the local losses to compute the total +loss. + +There's one other ingredient we need: we don't want to store the fully gathered +parameters from the forward pass for use on the backward pass. Instead, we want +to gather them again on the backward pass. We can express that by using +`jax.remat` with a [custom +policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +(or a `custom_vjp`), though XLA typically does that rematerialization +automatically. + +This general [FSDP +approach](https://engineering.fb.com/2021/07/15/open-source/fsdp/) is similar +to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and +[ZeRO-3](https://arxiv.org/abs/1910.02054). 

```{code-cell}
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)
```

Again we can check that the loss and its gradients match the reference model:

```{code-cell}
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
```

### 8-way tensor parallelism (TP)

Usually we don't use tensor model parallelism by itself, but seeing it in
isolation is a good warmup for parallel matrix multiplication. It's also a good
example of using `shard_map` in a library function, called in a larger
`jit`-based computation.

The parallelization idea is that we'll keep the data/activations sharded over
their feature axis (rather than their batch axis), and we'll similarly shard
weight matrices over their input-feature axis (and biases over their feature
axis). Then to perform the parallel matrix multiplication, we'll perform local
matrix multiplications followed by a `psum_scatter` to sum the local results
and efficiently scatter the result's shards.

```{code-cell}
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!
```

### FSDP + TP, with `shard_map` at the top level

We can compose these strategies together, using multiple axes of parallelism.
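
One new ingredient in the combined version is the parameter sharding spec: a
`PartitionSpec` entry that names a tuple of mesh axes, like
`P(('batch', 'feats'))`, splits that single array dimension over both mesh axes
at once (8 ways on a 4x2 mesh). Here's a tiny sketch of just that, on a
throwaway mesh and array (`mesh2` and `v` are illustrative names, not used
later):

```{code-cell}
# Sketch: sharding one array dimension over two mesh axes at once.
mesh2 = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'feats'))
v = jnp.arange(16.)
v_sharded = jax.device_put(v, NamedSharding(mesh2, P(('batch', 'feats'))))
jax.debug.visualize_array_sharding(v_sharded)   # eight 2-element shards
```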

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as the previous predict_fsdp definition, except we use gemm_tp-style
# psum_scatter matmuls over 'feats'
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')
```

Notice how we have to do _two_ collective reductions: one over `'feats'` and
one over `'batch'`. In the pure TP example, we didn't write the `'feats'`
reduction explicitly because we only used `shard_map` within `gemm_tp`; in the
caller `loss_tp`, the compiler automatically translated our use of `jnp.sum` to
perform a `psum` as needed given the sharded result returned by `predict_tp`.

```{code-cell}
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params_, batch_)))
```

### SPMD pipeline parallelism (PP)

With pipeline parallelism we aim to parallelize the evaluation of layers at
different depths in our network. For example, one device might compute the
application of the first layer while another device computes the application of
the second; when they finish, the first device passes its results to the second
while the second passes its results to the device responsible for the third
layer, and the process repeats. In general the number of pipeline stages may be
different from the number of layers, as each stage may be responsible for
multiple layers.

With SPMD pipelining, we exploit the fact that most layers in the network apply
the same computation, just with different parameter values. In particular, we
can stack together all the parameters except for those of the first and last
layers, then use a `shard_map` to map over blocks of those layer parameters,
where each block of parameters corresponds to a pipeline stage. We then use the
`jax.lax.ppermute` collective to shift data down the parallel pipeline.

This particular pipelining strategy is essentially [the GPipe
strategy](https://arxiv.org/abs/1811.06965). There are several variants, as
well as quite different strategies, and which is appropriate can depend on the
speed of the networking between stages and on batch sizes. But for this
tutorial we'll focus on just one strategy.
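
The workhorse collective for passing activations between stages is
`jax.lax.ppermute`, which sends each instance's value to a destination instance
according to a list of `(source, destination)` index pairs. Here's a minimal
sketch of the circular shift we'll rely on, using its own small two-device mesh
separate from the pipeline code below:

```{code-cell}
# Illustrative sketch of ppermute as a circular shift between two stages.
tiny_mesh = Mesh(jax.devices()[:2], ('stages',))

@partial(shard_map, mesh=tiny_mesh, in_specs=P('stages'), out_specs=P('stages'))
def shift_right(x_block):
  # send each stage's block to the next stage, wrapping around at the end
  return jax.lax.ppermute(x_block, 'stages', [(i, (i + 1) % 2) for i in range(2)])

print(shift_right(jnp.arange(4.)))   # [2. 3. 0. 1.]
```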

First, we choose some pipeline parameters:

```{code-cell}
L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2  # number of stages
B = 8  # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
```

```{code-cell}
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                         inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
```

```{code-cell}
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    # the first stage feeds the next microbatch into its first slot
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    # each layer in this stage processes the microbatch currently in its slot
    state = jax.vmap(fn)(stage_params, state)
    # the last stage writes finished microbatches into the output buffer
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    # shift activations down the pipeline (and rotate buffers when needed)
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):      # every K steps, rotate the input buffer
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):  # every K steps (offset), rotate the output buffer
    outputs = sh(outputs, +1)
  return state, inputs, outputs
```

```{code-cell}
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
```

```{code-cell}
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
```

```{code-cell}
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)  # don't crash
```
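
Finally, if you want to see the pipeline's communication pattern, one option
(an extra inspection step, not part of the recipe above) is to dump the
optimized compiler IR, as suggested earlier for the data-parallel case, and
look for the `collective-permute` ops that `ppermute` lowers to:

```{code-cell}
# Count the collective-permute ops in the optimized HLO for the pipelined loss.
hlo = jax.jit(loss_pp).lower(params_, batch_).compile().as_text()
print(hlo.count('collective-permute'))
```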