diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst
index 029f72e83..eddca16d5 100644
--- a/docs/advanced_guide.rst
+++ b/docs/advanced_guide.rst
@@ -18,6 +18,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu
multi_process
notebooks/Distributed_arrays_and_automatic_parallelization
+ notebooks/shard_map
notebooks/xmap_tutorial
.. toctree::
diff --git a/docs/conf.py b/docs/conf.py
index 917382d67..5becd88de 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -209,6 +209,7 @@ nb_execution_excludepatterns = [
'notebooks/xmap_tutorial.*',
'notebooks/Distributed_arrays_and_automatic_parallelization.*',
'notebooks/autodiff_remat.*',
+ 'notebooks/shard_map.*',
# Requires accelerators
'pallas/quickstart.*',
]
diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
new file mode 100644
index 000000000..b4ce840bb
--- /dev/null
+++ b/docs/notebooks/shard_map.ipynb
@@ -0,0 +1,1964 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "97c57a94",
+ "metadata": {},
+ "source": [
+ "# Intro\n",
+ "\n",
+ "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n",
+ "\n",
+ "`shard_map` is complementary to, and comopsable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n",
+ "\n",
+ "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n",
+ "\n",
+ "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n",
+ "\n",
+ "## So, let's see a `shard_map`!\n",
+ "\n",
+ "Without further ado, here's a toy example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3f562ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "from jax.sharding import Mesh, PartitionSpec as P\n",
+ "from jax.experimental import mesh_utils\n",
+ "from jax.experimental.shard_map import shard_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64a910c1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "devices = mesh_utils.create_device_mesh((4, 2))\n",
+ "mesh = Mesh(devices, axis_names=('x', 'y'))\n",
+ "\n",
+ "a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
+ "b = jnp.arange(16 * 4.).reshape(16, 4)\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n",
+ " out_specs=P('x', None))\n",
+ "def matmul_basic(a_block, b_block):\n",
+ " # a_block: f32[2, 8]\n",
+ " # b_block: f32[8, 4]\n",
+ " c_partialsum = jnp.dot(a_block, b_block)\n",
+ " c_block = jax.lax.psum(c_partialsum, 'y')\n",
+ " # c_block: f32[2, 4]\n",
+ " return c_block\n",
+ "\n",
+ "c = matmul_basic(a, b) # c: f32[8, 4]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "85c92753",
+ "metadata": {},
+ "source": [
+ "This function computes a matrix multiply in parallel by performing local block matrix multipiles followed by a collective sum operation. We can check the result is correct:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "09c382bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.tree_util import tree_map, tree_all\n",
+ "\n",
+ "def allclose(a, b):\n",
+ " return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))\n",
+ "\n",
+ "allclose(c, jnp.dot(a, b))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7618d7f8",
+ "metadata": {},
+ "source": [
+ "The result is sharded along its rows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "932af90f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "jax.debug.visualize_array_sharding(c)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a79329b0",
+ "metadata": {},
+ "source": [
+ "At a high level, `shard_map` is kind of like `vmap` or `pmap`, in that we're\n",
+ "mapping a function over pieces of array data, but notice that\n",
+ "* `shard_map` slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas `vmap` would reduce the rank by mapping away an axis;\n",
+ "* the `mesh` argument lets us control precise device placement of computation and results;\n",
+ "* we're mapping over multiple data axes at once, and setting up multiple axis names for collectives (both `'x'` and `'y'` here);\n",
+ "* since we're not using `jax.jit` yet, everything is eagerly evaluated, and we can even `print` intermediate values for debugging.\n",
+ "\n",
+ "The above code is performing the same computation as this `jax.jit` automatic parallelization code:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f9cf3cdd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.sharding import NamedSharding\n",
+ "\n",
+ "a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))\n",
+ "b = jax.device_put(b, NamedSharding(mesh, P('y', None)))\n",
+ "\n",
+ "@jax.jit\n",
+ "def matmul_reference(a, b):\n",
+ " c = jnp.dot(a, b)\n",
+ " return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))\n",
+ "\n",
+ "c_ref = matmul_reference(a, b)\n",
+ "allclose(c_ref, jnp.dot(a, b))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ab826433",
+ "metadata": {},
+ "source": [
+ "We can think of `shard_map` as performing a `device_put` or\n",
+ "`with_sharding_constraint` on its inputs according to its `mesh` and `in_specs`\n",
+ "arguments, so the blocks over which `matmul_basic` operates are the same as in\n",
+ "`matmul_reference`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8497b5f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('a blocks:'); jax.debug.visualize_array_sharding(a)\n",
+ "print('b blocks:'); jax.debug.visualize_array_sharding(b)\n",
+ "print('c blocks:'); jax.debug.visualize_array_sharding(c)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "532fe5f6",
+ "metadata": {},
+ "source": [
+ "## Slow down, start with the basics!\n",
+ "\n",
+ "### Rank-reducing vs rank-preserving maps\n",
+ "\n",
+ "We can think of `vmap` and `pmap` as unstacking each array input along an axis\n",
+ "(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to\n",
+ "each piece, and stacking the results back together, at least when collectives\n",
+ "aren't involved:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cb8e1883",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def check_vmap(f, xs):\n",
+ " ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)\n",
+ " expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics\n",
+ " print(allclose(ans, expected))\n",
+ "\n",
+ "check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9d55b900",
+ "metadata": {},
+ "source": [
+ "For example, if `xs` had shape `f32[8,5]` then each `x` would have shape\n",
+ "`f32[5]`, and if each `f(x)` had shape `f32[3,7]` then the final stacked result\n",
+ "`vmap(f)(xs)` would have shape `f32[8,3,7]`. That is, each application of the\n",
+ "body function `f` takes as argument inputs with one fewer axis than the\n",
+ "corresponding argument to `vmap(f)`. We can say these are _rank-reducing maps_\n",
+ "with unstacking/stacking of inputs/outputs.\n",
+ "\n",
+ "The number of logical applications of `f`, or _instances_ of `f`, is determined\n",
+ "by the size of the input axis being mapped over: for example, if we map over an\n",
+ "input axis of size 8, semantically we get 8 logical applications of the\n",
+ "function.\n",
+ "\n",
+ "In contrast, `shard_map` does not have this rank-reducing behavior. Instead, we\n",
+ "can think of it as slicing (or \"unconcatenating\") along input axes into blocks,\n",
+ "applying the body function, and concatenating the results back together (again\n",
+ "when collectives aren't involved):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "30e89f3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "devices = np.array(jax.devices()[:4])\n",
+ "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n",
+ "\n",
+ "def check_shmap(f, y):\n",
+ " ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)\n",
+ " expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])\n",
+ " print(allclose(ans, expected))\n",
+ "\n",
+ "check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "985ff202",
+ "metadata": {},
+ "source": [
+ "Recall that jnp.split slices its input into equally-sized blocks with the same\n",
+ "rank, so that if in the above example `y` had shape `f32[8,5]` then each\n",
+ "`y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape\n",
+ "`f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have\n",
+ "shape `f32[12,7]`. So `shard_map` maps over _shards_, or blocks, of its inputs.\n",
+ "We can say it's a _rank-preserving map_ with unconcatenating/concatenating of\n",
+ "its inputs/outputs.\n",
+ "\n",
+ "The number of logical applications of `f` is determined by the mesh size, not\n",
+ "by any input axis size: for example, if we have a mesh of total size 4 (i.e.\n",
+ "over 4 devices) then semantically we get 4 logical applications of the\n",
+ "function, corresponding to the 4 devices physically computing them.\n",
+ "\n",
+ "### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n",
+ "\n",
+ "Each of the `in_specs` identifies some of the corresponding input array's axes\n",
+ "with mesh axes by name using `PartitionSpec`s, representing how to split (or\n",
+ "unconcatenate) that input into the blocks to which the body function is\n",
+ "applied. That identification determines the shard sizes; when an input axis is\n",
+ "identified with a mesh axis, the input is split (unconcatenated) along that\n",
+ "logical axis into a number of pieces equal to the corresponding mesh axis size.\n",
+ "(It's an error if the corresponding mesh axis size does not evenly divide the\n",
+ "input array axis size.) If an input's pspec does not mention a mesh axis name,\n",
+ "then there's no splitting over that mesh axis. For example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "08555009",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "devices = mesh_utils.create_device_mesh((4, 2))\n",
+ "mesh = Mesh(devices, ('i', 'j'))\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
+ "def f1(x_block):\n",
+ " print(x_block.shape) # prints (3, 12)\n",
+ " return x_block\n",
+ "\n",
+ "x1 = jnp.arange(12 * 12).reshape(12, 12)\n",
+ "y = f1(x1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f191681b",
+ "metadata": {},
+ "source": [
+ "Here, because the input pspec did not mention the mesh axis name `'j'`, no\n",
+ "input array axis is split over that mesh axis; similarly, because the second\n",
+ "axis of the input array is not identified with (and hence split over) any mesh\n",
+ "axis, application of `f1` gets a full view of the input along that axis.\n",
+ "\n",
+ "When a mesh axis is not mentioned in an input pspec, we can always rewrite to a\n",
+ "less efficient program where all mesh axes are mentioned but the caller\n",
+ "performs a `jnp.tile`, for example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2515872d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n",
+ "def f2(x_block):\n",
+ " print(x_block.shape)\n",
+ " return x_block\n",
+ "\n",
+ "x = jnp.arange(12 * 12).reshape(12, 12)\n",
+ "x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)\n",
+ "y = f2(x_) # prints (3,12), and f1(x) == f2(x_)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8be0595e",
+ "metadata": {},
+ "source": [
+ "In other words, because each input pspec can mention each mesh axis name zero\n",
+ "or one times, rather than having to mention each name exactly once, we can say\n",
+ "that in addition to the `jnp.split` built into its input, `shard_map` also has\n",
+ "a `jnp.tile` built into its input, at least logically (though the tiling may\n",
+ "not need to be carried out physically, depending on the arguments' physical\n",
+ "sharding layout). The tiling to use is not unique; we could also have tiled\n",
+ "along the first axis, and used the pspec `P(('j', 'i'), None)`.\n",
+ "\n",
+ "Physical data movement is possible on inputs, as each device needs to have a\n",
+ "copy of the appropriate data.\n",
+ "\n",
+ "### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n",
+ "\n",
+ "Analogously to the input side, each of the `out_specs` identifies some of the\n",
+ "corresponding output array's axes with mesh axes by name, representing how the\n",
+ "output blocks (one for each application of the body function, or equivalently\n",
+ "one for each physical device) should be assembled back together to form the\n",
+ "final output value. For example, in both the `f1` and `f2` examples above the\n",
+ "`out_specs` indicate we should form the final output by concatenating together\n",
+ "the block results along both axes, resulting in both cases an array `y` of\n",
+ "shape `(12, 24)`. (It's an error if an output shape of the body function, i.e.\n",
+ "an output block shape, has a rank too small for the concatenation described by\n",
+ "the corresponding output pspec.)\n",
+ "\n",
+ "When a mesh axis name is not mentioned in an output pspec, it represents an\n",
+ "un-tiling: when the user writes an output pspec which does not mention one of\n",
+ "the mesh axis names, they promise that the output blocks are equal along that\n",
+ "mesh axis, and so only one block along that axis is used in the output (rather\n",
+ "than concatenating all the blocks together along that mesh axis). For example,\n",
+ "using the same mesh as above:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "42fa392a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = jnp.array([[3.]])\n",
+ "\n",
+ "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n",
+ "print(z) # prints the same as jnp.tile(x, (4, 2))\n",
+ "\n",
+ "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n",
+ "print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))\n",
+ "\n",
+ "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n",
+ "print(z) # prints the same as jnp.tile(x, (1, 1)), or just x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9c25db8c",
+ "metadata": {},
+ "source": [
+ "The body function closing over an array value is equivalent to passing it as an\n",
+ "augment with a corresponding input pspec of P(None, None). As another example,\n",
+ "following more closely to the other examples above:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b3c0a0e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n",
+ "def f3(x_block):\n",
+ " return jax.lax.psum(x_block, 'j')\n",
+ "\n",
+ "x = jnp.arange(12 * 12).reshape(12, 12)\n",
+ "y3 = f3(x)\n",
+ "print(y3.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0adc8960",
+ "metadata": {},
+ "source": [
+ "The result has a second axis size of 6, half the size of the input's second\n",
+ "axis. In this case, the un-tile expressed by not mentioning the mesh axis name\n",
+ "`'j'` in the output pspec was safe because of the collective `psum`, which\n",
+ "ensures each output block is equal along the corresponding mesh axis. Here are\n",
+ "two more examples where we vary which mesh axes are mentioned in the output\n",
+ "pspec:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b65636dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
+ "def f4(x_block):\n",
+ " return jax.lax.psum(x_block, 'i')\n",
+ "\n",
+ "x = jnp.arange(12 * 12).reshape(12, 12)\n",
+ "y4 = f4(x)\n",
+ "print(y4.shape) # (3,12)\n",
+ "\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n",
+ "def f5(x_block):\n",
+ " return jax.lax.psum(x_block, ('i', 'j'))\n",
+ "\n",
+ "y5 = f5(x)\n",
+ "print(y5.shape) # (3,6)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39218e05",
+ "metadata": {},
+ "source": [
+ "On the physical side, not mentioning a mesh axis name in an output pspec\n",
+ "assembles an `Array` from the output device buffers with replicated layout\n",
+ "along that mesh axis.\n",
+ "\n",
+ "There is no runtime check that the output blocks are actually equal along a\n",
+ "mesh axis to be un-tiled along, or equivalently that the corresponding physical\n",
+ "buffers have equal values and thus can be interpreted as a replicated layout\n",
+ "for a single logical array. But we can provide a static check mechanism which\n",
+ "raises an error on all potentially-incorrect programs.\n",
+ "\n",
+ "Because the `out_specs` can mention mesh axis names zero or one times, and\n",
+ "because they can be mentioned in any order, we can say that in addition to the\n",
+ "`jnp.concatenate` built into its output, `shard_map` also has both an _untile_\n",
+ "and a _block transpose_ built into its output.\n",
+ "\n",
+ "Physical data movement is not possible on outputs, no matter the output pspec.\n",
+ "Instead, `out_specs` just encodes how to assemble the block outputs into\n",
+ "`Array`s, or physically how to interpret the buffers across devices as the\n",
+ "physical layout of a single logical `Array`.\n",
+ "\n",
+ "# API Specification\n",
+ "\n",
+ "```python\n",
+ "from jax.sharding import Mesh\n",
+ "Specs = PyTree[PartitionSpec]\n",
+ "\n",
+ "def shard_map(\n",
+ " f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n",
+ " auto: collections.abc.Set[AxisName] = frozenset([]),\n",
+ " check_rep: bool = True,\n",
+ ") -> Callable:\n",
+ " ...\n",
+ "```\n",
+ "where:\n",
+ "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n",
+ "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n",
+ "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n",
+ "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the obdy, as in the caller, rather than manually;\n",
+ "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n",
+ "\n",
+ "The shapes of the arguments passed to `f` have the same ranks as the arguments\n",
+ "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n",
+ "from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and\n",
+ "the corresponding `PartitionSpec` `spec` as roughly\n",
+ "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n",
+ "\n",
+ "# Collectives tutorial\n",
+ "\n",
+ "A `shard_map` need not be a pure map: function applications can communicate\n",
+ "with each other via _collectives_, using axis names defined in the `mesh`\n",
+ "argument.\n",
+ "\n",
+ "Recall that `shard_map` maps a function over shards, or blocks, of input data,\n",
+ "so that this:\n",
+ "\n",
+ "```python\n",
+ "mesh = Mesh(jax.devices(), ('i',))\n",
+ "x = jnp.arange(16.)\n",
+ "f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))\n",
+ "y = f_shmapped(x)\n",
+ "```\n",
+ "\n",
+ "Computes the same values, evaluating applications of `f` to the same argument\n",
+ "values, as this reference function:\n",
+ "\n",
+ "```python\n",
+ "def f_shmapped_ref(x):\n",
+ " x_blocks = jnp.array_split(x, mesh.shape[0])\n",
+ " y_blocks = [f(x_blk) for x_blk in x_blocks]\n",
+ " return jnp.concatenate(y_blocks)\n",
+ "```\n",
+ "\n",
+ "We call these applications of `f` to different argument shards _function\n",
+ "instances_. Each function instance is executed on a different device (or subset\n",
+ "of devices).\n",
+ "\n",
+ "These reference semantics work when `f` has no communication collectives in\n",
+ "it. But what if we want the function instances to communicate, corresponding\n",
+ "to having cross-device communication? That is, what are the reference\n",
+ "semantics when `f` contains a collective? Say `f` has just one collective, and\n",
+ "is of the form\n",
+ "\n",
+ "```python\n",
+ "def f(x_blk):\n",
+ " z_blk = f_part1(x_blk)\n",
+ " u_blk = collective(z_blk, axis_name)\n",
+ " v_blk = f_part2(x_blk, z_blk, u_blk)\n",
+ " return v_blk\n",
+ "```\n",
+ "\n",
+ "where we're assuming there's only one mesh axis we're mapping over, and\n",
+ "`axis_name` is the corresponding name for it. Then the reference semantics\n",
+ "would look more like:\n",
+ "\n",
+ "```python\n",
+ "def f_shmapped_ref(x):\n",
+ " x_blocks = jnp.array_split(x, mesh.shape[0])\n",
+ " z_blocks = [f_part1(x_blk) for x_blk in x_blocks]\n",
+ " u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]\n",
+ " v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk\n",
+ " in zip(x_blocks, z_blocks, u_blocks)]\n",
+ " return jnp.concatenate(v_blocks)\n",
+ "```\n",
+ "\n",
+ "Notice that `collective_ref` might depend on all the `z_blocks`. That is,\n",
+ "while `f_part1` and `f_part2` are mapped over blocks independently, a\n",
+ "collective introduces some amount of cross-block dependence. Physically, that\n",
+ "means communication across devices. Exactly what communication happens, and\n",
+ "what values are computed, depend on the collective.\n",
+ "\n",
+ "## `psum`\n",
+ "\n",
+ "The simplest collective may be `jax.lax.psum`, which computes an\n",
+ "all-reduce-sum along a device mesh axis (or multiple axes).\n",
+ "Here's a toy example:\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dac420e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "from jax import lax\n",
+ "\n",
+ "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
+ "from jax.experimental.shard_map import shard_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "38dd14b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh1d = Mesh(jax.devices()[:4], ('i',))\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n",
+ "def f1(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.psum(x_block, 'i')\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22f0b947",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n",
+ "y = f1(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e32df19e",
+ "metadata": {},
+ "source": [
+ "The prints show that each function application starts with its own chunk of\n",
+ "the argument value `x_block`. After the `psum`, each function application has\n",
+ "the same value of `y_block`, computed by summing the applications' `x_block`\n",
+ "values together.\n",
+ "\n",
+ "In the case where there's a single axis name in the computation, we could say\n",
+ "that the `collective_ref` reference implementation for `psum` is\n",
+ "\n",
+ "```python\n",
+ "def psum_ref(_, x_blocks):\n",
+ " tot = sum(x_blocks)\n",
+ " return [tot] * len(x_blocks)\n",
+ "```\n",
+ "\n",
+ "Notice also that because `f1` returns `y_block`, the result of a `psum` over\n",
+ "`'i'`, we can use `out_specs=P()` so the caller gets a single logical copy of\n",
+ "the result value, rather than a tiled result.\n",
+ "\n",
+ "When there is more than one mesh axis, we can perform a `psum` over\n",
+ "each one separately, or over multiple axes at once:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "39fba427",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n",
+ "def f2(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.psum(x_block, 'i')\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "y = f2(jnp.arange(16).reshape(4, 4))\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d2bdd59b",
+ "metadata": {},
+ "source": [
+ "By applying a `psum` over mesh axis `'i'`, we get values of `y_block` which\n",
+ "are equal along axis '`i'`, but not axis `'j'`. (So we can use\n",
+ "`out_specs=P(None, 'j')` to get a single logical result along that axis.)\n",
+ "\n",
+ "If we apply the `psum` over both axes, the `y_block` value is equal along both\n",
+ "axes:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2919056c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n",
+ "def f3(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.psum(x_block, ('i', 'j'))\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "y = f3(jnp.arange(16).reshape(4, 4))\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9d1f748",
+ "metadata": {},
+ "source": [
+ "In machine learning, we often use `psum` to compute total losses or, when we\n",
+ "have a `grad` inside the `shard_map`ped function body, total gradients.\n",
+ "\n",
+ "In the sequel, we'll see how `psum` can be implemented in terms of other\n",
+ "primitives, which gives some intuition about its communication cost.\n",
+ "\n",
+ "## `all_gather`\n",
+ "\n",
+ "Another fundamental operation is gathering array shards along an axis, so that\n",
+ "each function application has a full copy of the data along that axis:\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ac45aafc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f4(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "x = jnp.array([3, 9, 5, 2])\n",
+ "y = f4(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a16048aa",
+ "metadata": {},
+ "source": [
+ "The prints show that each function application again starts with its own chunk\n",
+ "of the argument value `x_block`. After the `all_gather`, they have a common\n",
+ "value, computed by concatenating the values of `x_block`.\n",
+ "\n",
+ "(Notice that we actually can't set `out_specs=P()` here. For technical\n",
+ "reasons related to automatic differentiation, we consider the output of\n",
+ "`all_gather` not to be guaranteed invariant across devices. If we wanted it to\n",
+ "be guaranteed invariant, we could use `jax.lax.all_gather_invariant`, or in\n",
+ "this case we could just avoid doing the `all_gather` in the function body and\n",
+ "instead just use `out_specs=P('i')` to perform the concatenation.)\n",
+ "\n",
+ "When `tiled=False` (the default), results are stacked along a new axis instead\n",
+ "of concatenated:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7d660032",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f5(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.all_gather(x_block, 'i', tiled=False)\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "y = f5(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "960bc5c3",
+ "metadata": {},
+ "source": [
+ "We could write the `collective_ref` reference semantics function for\n",
+ "`all_gather` as\n",
+ "\n",
+ "```python\n",
+ "def all_gather_ref(_, x_blocks, *, tiled=False):\n",
+ " combine = jnp.concatenate if tiled else jnp.stack\n",
+ " return [combine(x_blocks)] * len(x_blocks)\n",
+ "```\n",
+ "\n",
+ "In deep learning, we might use `all_gather`s on parameters in fully sharded\n",
+ "data parallelism (FSDP).\n",
+ "\n",
+ "# psum_scatter\n",
+ "\n",
+ "The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n",
+ "`psum` except each function instance gets only one shard of the result:\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fcc18ee8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f6(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n",
+ "y = f6(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b645b6b2",
+ "metadata": {},
+ "source": [
+ "As shown by the prints, each resulting `y_block` has a smaller size than the\n",
+ "argument `x_block`, unlike with `psum`. Moreover, compared to `psum`, here\n",
+ "each `y_block` only represents a slice of the sum of the `x_block`s across\n",
+ "function instances. (Even though each function instance gets only one shard of\n",
+ "the sum, the final output `y` is the same as in the `psum` example because\n",
+ "here we use `out_specs=P('i')` to concatenate each function instance's\n",
+ "output.)\n",
+ "\n",
+ "In terms of what values are computed, a `collective_ref` reference\n",
+ "implementation might look like:\n",
+ "\n",
+ "```python\n",
+ "def psum_scatter_ref(i, x_blocks, *, tiled=False):\n",
+ " axis_size = len(x_blocks)\n",
+ " tot = sum(x_blocks)\n",
+ " if tiled:\n",
+ " tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis\n",
+ " return [tot[i] for i in range(tot.shape[0])]\n",
+ "```\n",
+ "\n",
+ "It's not captured in the semantics reference implementation, but\n",
+ "`psum_scatter` is useful because these results can be computed more\n",
+ "efficiently, with less communication, than a full `psum`. In fact, one way to\n",
+ "think of `psum_scatter` is as \"the first half of a `psum`, before an\n",
+ "`all_gather`\". That is, one way to implement `psum` is:\n",
+ "\n",
+ "```python\n",
+ "def psum(x, axis_name):\n",
+ " summed_chunk = jax.lax.psum_scatter(x, axis_name)\n",
+ " return jax.lax.all_gather(summed_chunk, axis_name)\n",
+ "```\n",
+ "\n",
+ "Indeed, this implementation is often used on both TPU and GPU!\n",
+ "\n",
+ "The reason `psum_scatter` can require about half the communication as a full\n",
+ "`psum` is illustrated the `ppermute` section.\n",
+ "\n",
+ "Another intuition is that we can use `psum_scatter` to implement a distributed\n",
+ "matrix multiplication with inputs and outputs sharded over the same axis. In\n",
+ "machine learning, `psum_scatter` can be used in tensor-parallel matrix\n",
+ "multiplies or fully-sharded data parallel gradient accumulation, as shown in\n",
+ "the examples to follow.\n",
+ "\n",
+ "## `ppermute`\n",
+ "\n",
+ "The `jax.lax.ppermute` collective provides the most direct way for\n",
+ "function instances to send data to one another. Given a mesh axis and a\n",
+ "list of `(source_index, destination_index)` pairs representing indices along\n",
+ "that mesh axis, `ppermute` sends its argument value from each source function\n",
+ "instance to each destination:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc06e7b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f7(x_block):\n",
+ " sz = jax.lax.psum(1, 'i')\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "y = f7(jnp.arange(8))\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0f825fd1",
+ "metadata": {},
+ "source": [
+ "In this case, with just two function instances, each instance's value of\n",
+ "`y_block` is the other's value of `x_block`.\n",
+ "\n",
+ "Source indices and destination indices can't be repeated. If an index does not\n",
+ "appear as a destination, then the value of the corresponding function\n",
+ "instance's result is an array of zeros.\n",
+ "\n",
+ "A `collective_ref` reference implementation could look like\n",
+ "\n",
+ "```python\n",
+ "def ppermute_ref(i, x_blocks, perm):\n",
+ " results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)\n",
+ " for src, dst in perm:\n",
+ " results[dst] = x_blocks[src]\n",
+ " return results\n",
+ "```\n",
+ "\n",
+ "Other collectives can be implemented efficiently, in terms of total\n",
+ "communication, using `ppermute`s where each function passes data only to its\n",
+ "neighbors. For example, we could implement `psum_scatter` using a sequence of\n",
+ "`ppermute`s and local additions this way:\n",
+ "\n",
+ "
\n",
+ "\n",
+ "Or, with a numerical example:\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ "Intuitively, on each iteration each function instance sends 'up' the value it\n",
+ "received on the previous iteration, and reduces (adds) the value it receives\n",
+ "this iteration. In code, it might look like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8b7aa515",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def psum_scatter(x, axis_name, *, tiled=False):\n",
+ " size = jax.lax.psum(1, axis_name)\n",
+ " idx = jax.lax.axis_index(axis_name) # function instance index along axis_name\n",
+ " if tiled:\n",
+ " x = x.reshape(size, -1, *x.shape[1:]) # split leading axis\n",
+ " shift = partial(jax.lax.ppermute, axis_name=axis_name,\n",
+ " perm=[(i, (i - 1) % size) for i in range(size)])\n",
+ " for i in range(1, size):\n",
+ " update = shift(x[(idx + i) % size])\n",
+ " x = x.at[(idx + i + 1) % size].add(update)\n",
+ " return x[idx]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "77cc7c7c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f8(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = psum_scatter(x_block, 'i', tiled=True)\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n",
+ "y = f8(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "61400a38",
+ "metadata": {},
+ "source": [
+ "On TPU, there are higher-dimensional variants of this algorithm to exploit\n",
+ "multiple bidirectional physical mesh axes.\n",
+ "\n",
+ "Notice that `psum_scatter` is the transpose of `all_gather`. Indeed, a way to\n",
+ "implement `all_gather` in terms of `ppermute` looks like the reverse of the\n",
+ "above process:\n",
+ "\n",
+ "
\n",
+ "\n",
+ "In deep learning, we might use `ppermute` when implementing SPMD pipeline\n",
+ "parallelism, where we divide our network along its depth into stages and\n",
+ "evaluate the applications of stages in parallel. Or we might use `ppermute` in\n",
+ "parallelizing the evaluation of convolutional layers, where we shard over\n",
+ "spatial axes and thus devices must communicate \"halos\" to each other. Or it\n",
+ "may be used under-the-hood in tensor-parallel matrix multiplies.\n",
+ "\n",
+ "## `all_to_all`\n",
+ "\n",
+ "A final collective is `all_to_all`, which is essentially a block matrix\n",
+ "transpose operating along one positional axis and one cross-device axis:\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6fa39069",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n",
+ "def f9(x_block):\n",
+ " print('BEFORE:\\n', x_block)\n",
+ " y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n",
+ " tiled=True)\n",
+ " print('AFTER:\\n', y_block)\n",
+ " return y_block\n",
+ "\n",
+ "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n",
+ "y = f9(x)\n",
+ "print('FINAL RESULT:\\n', y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "549af5f6",
+ "metadata": {},
+ "source": [
+ "The `split_axis` argument indicates which positional axis should be sharded\n",
+ "and partitioned across the mesh axis. The `concat_axis` argument indicates the\n",
+ "axis along which the communicated results should be concatenated or stacked.\n",
+ "\n",
+ "When `tiled=False` (the default), the `split_axis` axis size must equal the\n",
+ "size of the mesh axis named `axis_name`, and a new axis of that size is\n",
+ "created at position `concat_axis` for the stacked results. When `tiled=True`,\n",
+ "the `split_axis` axis size need only be evenly divisible by the size of the\n",
+ "mesh axis, and results are concatenated along the existing axis `concat_axis`.\n",
+ "\n",
+ "The `collective_ref` reference semantics when `split_axis=0` and\n",
+ "`concat_axis=0` might look like:\n",
+ "\n",
+ "```python\n",
+ "def all_to_all_ref(_, x_blocks, *, tiled=False):\n",
+ " axis_size = len(x_blocks)\n",
+ " if tiled:\n",
+ " splits = [jnp.array_split(x, axis_size) for x in x_blocks]\n",
+ " return [jnp.concatenate(s) for s in zip(*splits)]\n",
+ " else:\n",
+ " splits = [list(x) for x in x_blocks]\n",
+ " return [jnp.stack(s) for s in zip(*splits)]\n",
+ "```\n",
+ "\n",
+ "In deep learning, we might use `all_to_all` in mixture-of-expert routing,\n",
+ "where we first sort our local batch of examples according to which expert they\n",
+ "should go to, then apply an `all_to_all` to redistribute examples to experts.\n",
+ "\n",
+ "# Toy examples\n",
+ "\n",
+ "How might we use `shard_map` and collective communication in practice? These\n",
+ "examples, while simple, give some idea.\n",
+ "\n",
+ "## Matrix multiplies\n",
+ "\n",
+ "Parallelizing matrix multiplication is central in scaling up deep learning\n",
+ "models, both for training and for inference. When `jax.jit` automatically\n",
+ "parallelizes matrix multiplication, it can use one of several different\n",
+ "strategies, depending on matrix sizes, hardware details, and other factors. How\n",
+ "might we write some of those parallelized routines more explicitly using\n",
+ "`shard_map`? And how can we optimize them to get better compute/communication\n",
+ "overlap and thus improve FLOP utilization?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86b85f05",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
+ "from jax.experimental.shard_map import shard_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bcd9b561",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh = Mesh(jax.devices()[:4], ('i',))\n",
+ "\n",
+ "def device_put(x, pspec):\n",
+ " return jax.device_put(x, NamedSharding(mesh, pspec))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e2b33b9",
+ "metadata": {},
+ "source": [
+ "### Example 1: `all-gather` on one side\n",
+ "\n",
+ "Consider performing a matrix multiplication where we shard the left-hand side\n",
+ "argument (can think: parameters) on its leading (non-contracting) dimension:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc221220",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lhs_spec = P('i', None)\n",
+ "lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3e76bcfd",
+ "metadata": {},
+ "source": [
+ "And wee shard the right-hand side argument (can think: activations) on its\n",
+ "contracting dimension, with a similar sharding for the output:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "272bd303",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rhs_spec = P('i', None)\n",
+ "rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2691cf9c",
+ "metadata": {},
+ "source": [
+ "To perform this matrix multiplication, we can first all-gather the right-hand\n",
+ "side and then perform local matrix multiplies against the sharded left-hand\n",
+ "side:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8971775e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_allgather(lhs_block, rhs_block):\n",
+ " rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)\n",
+ " return lhs_block @ rhs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a582e7ca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = matmul_allgather(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "41eb2743",
+ "metadata": {},
+ "source": [
+ "That's great, but we're not getting any compute/communication overlap\n",
+ "here: before we can start the matmul, we need the all_gather to complete.\n",
+ "Here's a profile using the same code, but on larger example shapes (`(8192,\n",
+ "8192)` for `lhs` and `(8192, 1024)` for `rhs`):\n",
+ "\n",
+ "
\n",
+ "\n",
+ "We can get compute/communication overlap if instead of calling `all_gather` we\n",
+ "basically inline our above implementation of `all_gather` in terms of\n",
+ "`ppermute`, then interleave steps of the gather permutation with local matrix\n",
+ "multiplies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9a6b952b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_allgather_overlapped(lhs_block, rhs_block):\n",
+ " size = jax.lax.psum(1, 'i')\n",
+ " idx = jax.lax.axis_index('i')\n",
+ " shift = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i + 1) % size) for i in range(size)])\n",
+ "\n",
+ " B = lhs_block.shape[1] // size\n",
+ " lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)\n",
+ "\n",
+ " out_block = lhs_blocks(idx) @ rhs_block\n",
+ " for i in range(1, size):\n",
+ " rhs_block = shift(rhs_block)\n",
+ " out_block += lhs_blocks((idx - i) % size) @ rhs_block\n",
+ " return out_block"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cbe3cff0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = matmul_allgather_overlapped(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d683ab3",
+ "metadata": {},
+ "source": [
+ "This implementation allows overlap between communication and computation, and\n",
+ "also avoids gathering a large intermediate onto each device. But on TPU it uses\n",
+ "only half the interconnect bandwidth by permuting in only one direction along\n",
+ "the ring. To permute bidirectionally, we just split the blocks in half and send\n",
+ "each half in each direction:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b777f21d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n",
+ " size = jax.lax.psum(1, 'i')\n",
+ " idx = jax.lax.axis_index('i')\n",
+ " shift_up = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i + 1) % size) for i in range(size)])\n",
+ " shift_dn = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i - 1) % size) for i in range(size)])\n",
+ "\n",
+ " B = lhs_block.shape[1] // size // 2 # half-size blocks\n",
+ " lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)\n",
+ "\n",
+ " rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)\n",
+ " out_block = lhs_blocks(idx, 0) @ rhs_block_lo\n",
+ " out_block += lhs_blocks(idx, 1) @ rhs_block_hi\n",
+ " for i in range(1, size):\n",
+ " rhs_block_lo = shift_up(rhs_block_lo)\n",
+ " rhs_block_hi = shift_dn(rhs_block_hi)\n",
+ " out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo\n",
+ " out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi\n",
+ " return out_block"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e40d1e8c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = matmul_allgather_overlapped_bidi(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "884f5535",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "In practice, to reduce compile times we would probably roll this into a\n",
+ "`jax.lax.fori_loop`. We might also have additional axes of parallelism\n",
+ "involved.\n",
+ "\n",
+ "### Example 2: `psum_scatter` the result\n",
+ "\n",
+ "Another sharding we might start with has both `lhs` and `rhs` sharded along\n",
+ "their contracting dimensions, with the output sharded like `rhs` again:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ee46a15f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lhs_spec = P(None, 'i')\n",
+ "lhs = device_put(lhs, lhs_spec)\n",
+ "\n",
+ "rhs_spec = P('i', None)\n",
+ "rhs = device_put(rhs, rhs_spec)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c1c7200f",
+ "metadata": {},
+ "source": [
+ "Here we can use a `reduce_scatter` to perform the contraction sum over shards:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "05d6ad68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_psumscatter(lhs_block, rhs_block):\n",
+ " out_summand = lhs_block @ rhs_block\n",
+ " return jax.lax.psum_scatter(out_summand, 'i', tiled=True)\n",
+ "\n",
+ "out = matmul_psumscatter(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cad74c82",
+ "metadata": {},
+ "source": [
+ "But the scattering communication must wait for the entire local matrix multiply\n",
+ "to finish before it can start. To get communication/computation overlap, we can\n",
+ "inline an implementation of `psum_scatter` in terms of `ppermute`, then\n",
+ "interleave the communication steps with local matrix multiplies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "66dfac2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_psumscatter_overlapped(lhs_block, rhs_block):\n",
+ " size = jax.lax.psum(1, 'i')\n",
+ " idx = jax.lax.axis_index('i')\n",
+ " shift = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i - 1) % size) for i in range(size)])\n",
+ " lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis\n",
+ "\n",
+ " out_summand = lhs_block[(idx + 1) % size] @ rhs_block\n",
+ " for i in range(1, size):\n",
+ " out_summand = shift(out_summand)\n",
+ " out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block\n",
+ " return out_summand"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "042dde0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = matmul_psumscatter_overlapped(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c73f21e",
+ "metadata": {},
+ "source": [
+ "As in the previous example, to fully utilize interconnects on TPU, we'd run a\n",
+ "bidirectional version:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "52db4d9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n",
+ " out_specs=rhs_spec)\n",
+ "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n",
+ " size = jax.lax.psum(1, 'i')\n",
+ " idx = jax.lax.axis_index('i')\n",
+ " shift_up = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i + 1) % size) for i in range(size)])\n",
+ " shift_dn = partial(jax.lax.ppermute, axis_name='i',\n",
+ " perm=[(i, (i - 1) % size) for i in range(size)])\n",
+ "\n",
+ " B = lhs_block.shape[0] // size // 2 # half-size blocks\n",
+ " lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)\n",
+ "\n",
+ " out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block\n",
+ " out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block\n",
+ " for i in range(1, size):\n",
+ " out_summand_lo = shift_up(out_summand_lo)\n",
+ " out_summand_hi = shift_dn(out_summand_hi)\n",
+ " out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block\n",
+ " out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block\n",
+ " return jnp.concatenate([out_summand_lo, out_summand_hi])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c29971e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "out = matmul_psumscatter_overlapped_bidi(lhs, rhs)\n",
+ "print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "60c2d2bc",
+ "metadata": {},
+ "source": [
+ "## Neural networks\n",
+ "\n",
+ "We can use `shard_map` to parallelize computation in neural networks, either by\n",
+ "itself or in combination with the automatic partitioning in `jax.jit`. This\n",
+ "section has a few examples based on this toy neural network and random data:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "981ad73a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "def predict(params, inputs):\n",
+ " for W, b in params:\n",
+ " outputs = jnp.dot(inputs, W) + b\n",
+ " inputs = jax.nn.relu(outputs)\n",
+ " return outputs\n",
+ "\n",
+ "def loss(params, batch):\n",
+ " inputs, targets = batch\n",
+ " predictions = predict(params, inputs)\n",
+ " return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e6652af0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def init_layer(key, n_in, n_out):\n",
+ " k1, k2 = jax.random.split(key)\n",
+ " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
+ " b = jax.random.normal(k2, (n_out,))\n",
+ " return W, b\n",
+ "\n",
+ "def init(key, layer_sizes, batch_size):\n",
+ " key, *keys = jax.random.split(key, len(layer_sizes))\n",
+ " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
+ "\n",
+ " key, *keys = jax.random.split(key, 3)\n",
+ " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
+ " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
+ "\n",
+ " return params, (inputs, targets)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1db636cc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n",
+ "batch_size = 32\n",
+ "\n",
+ "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "43f9e760",
+ "metadata": {},
+ "source": [
+ "Compare these examples with the purely [automatic partitioning examples in the\n",
+ "\"Distributed arrays and automatic partitioning\"\n",
+ "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n",
+ "While in those automatic partitioning examples we don't need to edit the model\n",
+ "functions to use different parallelization strategies, with `shard_map` we\n",
+ "often do.\n",
+ "\n",
+ "### 8-way batch data parallelism\n",
+ "\n",
+ "The simplest multi-device parallelism strategy is to shard the batch of inputs\n",
+ "and targets over multiple devices, replicate the parameters over those devices,\n",
+ "and apply the model in parallel to those shards of data. To evaluate the total\n",
+ "loss, the devices need only communicate with a scalar-sized all-reduce-sum at\n",
+ "the end. (To evaluate the gradient of the loss, the devices must perform\n",
+ "all-reduce-sums of parameter gradients in the backward pass.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d6417125",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
+ "from jax.experimental.shard_map import shard_map\n",
+ "from jax.experimental import mesh_utils\n",
+ "\n",
+ "devices = mesh_utils.create_device_mesh((8,))\n",
+ "\n",
+ "# replicate initial params on all devices, shard data batch over devices\n",
+ "mesh = Mesh(devices, ('batch',))\n",
+ "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
+ "params = jax.device_put(params, NamedSharding(mesh, P()))\n",
+ "\n",
+ "# adapt the loss function to sum the losses across devices\n",
+ "def loss_dp(params, batch):\n",
+ " @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),\n",
+ " check_rep=False) # TODO remove check_rep=False\n",
+ " def loss_spmd(local_batch):\n",
+ " inputs, targets = local_batch\n",
+ " predictions = predict(params, inputs) # use reference 'predict`\n",
+ " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n",
+ " return jax.lax.pmean(local_loss, 'batch')\n",
+ " return loss_spmd(batch)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9fe63185",
+ "metadata": {},
+ "source": [
+ "We can check that the loss and its gradients match the reference (base) model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f1a7155",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(jax.jit(loss)(params, batch))\n",
+ "print(jax.jit(loss_dp)(params, batch))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "89e0a24d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def allclose(a, b):\n",
+ " return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))\n",
+ "\n",
+ "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n",
+ " jax.jit(jax.grad(loss_dp))(params, batch)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0033ac3b",
+ "metadata": {},
+ "source": [
+ "We can print the compiler IR to inspect the gradient computation and verify\n",
+ "that the collective all-reduce-sum operations happen where we'd expect: at the\n",
+ "end of the forward pass to compute the loss value, and in the backward pass to\n",
+ "compute the total parameter gradients.\n",
+ "\n",
+ "### 8-way fully sharded data parallelism (FSDP)\n",
+ "\n",
+ "Another strategy is to additionally shard the parameters over the devices,\n",
+ "all-gathering each one when the full value is needed for the `jnp.dot` or bias\n",
+ "addition. Since we only have one full parameter in local device memory at a\n",
+ "time, rather than keeping all parameters in all device memories as in the\n",
+ "preceding DP example, we free up significant memory that we can use for larger\n",
+ "models or larger batch sizes. And because XLA will overlap computation and\n",
+ "inter-device communication, the wall-clock time doesn't suffer.\n",
+ "\n",
+ "So now we need collectives in two places: the model prediction function\n",
+ "`predict` needs to all-gather the parameters before they're used, and as in the\n",
+ "DP case the loss function needs to sum the local losses to compute the total\n",
+ "loss.\n",
+ "\n",
+ "There's one other ingredient we need: we don't want to store the fully gathered\n",
+ "parameters from the forward pass for use on the backward pass. Instead, we want\n",
+ "to gather them again on the backward pass. We can express that by using\n",
+ "`jax.remat` with a [custom\n",
+ "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n",
+ "(or a `custom_vjp`), though XLA typically does that rematerialization\n",
+ "automatically.\n",
+ "\n",
+ "This general [FSDP\n",
+ "approach](https://engineering.fb.com/2021/07/15/open-source/fsdp/) is similar\n",
+ "to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and\n",
+ "[ZeRO-3](https://arxiv.org/abs/1910.02054)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4538cd6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# shard data batch *and params* over devices\n",
+ "mesh = Mesh(devices, ('batch',))\n",
+ "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
+ "params = jax.device_put(params, NamedSharding(mesh, P('batch')))\n",
+ "\n",
+ "# adapt the prediction function to gather weights just before their use,\n",
+ "# and to re-gather them on the backward pass (rather than saving them)\n",
+ "@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')\n",
+ "def predict_fsdp(params_frag, inputs):\n",
+ " for W_frag, b_frag in params_frag:\n",
+ " W = jax.lax.all_gather(W_frag, 'batch', tiled=True)\n",
+ " b = jax.lax.all_gather(b_frag, 'batch', tiled=True)\n",
+ " outputs = jnp.dot(inputs, W) + b\n",
+ " inputs = jax.nn.relu(outputs)\n",
+ " return outputs\n",
+ "\n",
+ "def loss_fsdp(params, batch):\n",
+ " @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n",
+ " def loss_spmd(local_params, local_batch):\n",
+ " inputs, targets = local_batch\n",
+ " predictions = predict_fsdp(local_params, inputs)\n",
+ " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n",
+ " return jax.lax.pmean(local_loss, 'batch')\n",
+ " return loss_spmd(params, batch)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "30dd5f99",
+ "metadata": {},
+ "source": [
+ "Again we can check that the loss and its gradients match the reference model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1035b6fc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(jax.jit(loss)(params, batch))\n",
+ "print(jax.jit(loss_fsdp)(params, batch))\n",
+ "\n",
+ "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n",
+ " jax.jit(jax.grad(loss_fsdp))(params, batch)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f88ddefe",
+ "metadata": {},
+ "source": [
+ "### 8-way tensor parallelism (TP)\n",
+ "\n",
+ "Usually we don't use tensor model parallelism by itself, but seeing it in\n",
+ "isolation is a good warmup on parallel matrix multiplication. It's also a good\n",
+ "example of using `shard_map` in a library function, called in a larger\n",
+ "`jit`-based computation.\n",
+ "\n",
+ "The parallelization idea is that we'll keep the data/activations sharded over\n",
+ "their feature axis (rather than their batch axis), and we'll similarly shard weight\n",
+ "matrices over their input-feature axis (and biases over their feature axis).\n",
+ "Then to perform the parallel matrix multiplication, we'll perform local matrix\n",
+ "multiplications followed by a `psum_scatter` to sum the local results and\n",
+ "efficiently scatter the result's shards."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7bd1fb92",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "devices = mesh_utils.create_device_mesh((8,))\n",
+ "mesh = Mesh(devices, ('feats',))\n",
+ "\n",
+ "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
+ "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
+ "\n",
+ "def predict_tp(params, inputs):\n",
+ " for W, b in params:\n",
+ " outputs = gemm_tp(inputs, W, b)\n",
+ " inputs = jax.nn.relu(outputs)\n",
+ " return outputs\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh,\n",
+ " in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n",
+ " out_specs=P(None, 'feats'))\n",
+ "def gemm_tp(inputs, W, b):\n",
+ " block_result = jnp.dot(inputs, W)\n",
+ " return jax.lax.psum_scatter(block_result, 'feats',\n",
+ " scatter_dimension=1, tiled=True) + b\n",
+ "\n",
+ "def loss_tp(params, batch):\n",
+ " inputs, targets = batch\n",
+ " predictions = predict_tp(params, inputs)\n",
+ " return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf59d537",
+ "metadata": {},
+ "source": [
+ "### FSDP + TP, with `shard_map` at the top level\n",
+ "\n",
+ "We can compose these strategies together, using multiple axes of parallelism."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4605705",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "devices = mesh_utils.create_device_mesh((4, 2))\n",
+ "mesh = Mesh(devices, ('batch', 'feats'))\n",
+ "\n",
+ "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
+ "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
+ "\n",
+ "# mostly same as previous predict_fsdp definition, except we call gemm_tp\n",
+ "@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')\n",
+ "def predict_fsdp_tp(params_frag, inputs):\n",
+ " for W_frag, b_frag in params_frag:\n",
+ " W = jax.lax.all_gather(W_frag, 'batch', tiled=True)\n",
+ " b = jax.lax.all_gather(b_frag, 'batch', tiled=True)\n",
+ " block_result = jnp.dot(inputs, W)\n",
+ " outputs = jax.lax.psum_scatter(block_result, 'feats',\n",
+ " scatter_dimension=1, tiled=True) + b\n",
+ " inputs = jax.nn.relu(outputs)\n",
+ " return outputs\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh,\n",
+ " in_specs=(P(('feats', 'batch')), P('batch', 'feats')),\n",
+ " out_specs=P())\n",
+ "def loss_fsdp_tp(local_params, local_batch):\n",
+ " inputs, targets = local_batch\n",
+ " predictions = predict_fsdp_tp(local_params, inputs)\n",
+ " sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')\n",
+ " return jax.lax.pmean(jnp.mean(sq_err), 'batch')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8220f75c",
+ "metadata": {},
+ "source": [
+ "Notice how we have to do _two_ collective reductions: one over `'feats'` and\n",
+ "one over `'batch'`. In the pure TP example, we didn't write the `'feats'`\n",
+ "reduction explicitly because we only used `shard_map` within `gemm_tp`; in the\n",
+ "caller `loss_tp`, the compiler automatically translated our use of `jnp.sum` to\n",
+ "perform a `psum` as needed given the sharded result returned by `predict_tp`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "30b87c53",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(jax.jit(loss)(params, batch))\n",
+ "print(jax.jit(loss_fsdp_tp)(params_, batch_))\n",
+ "\n",
+ "print(allclose(jax.jit(jax.grad(loss))(params, batch),\n",
+ " jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94a352ca",
+ "metadata": {},
+ "source": [
+ "### SPMD pipeline parallelism (PP)\n",
+ "\n",
+ "With pipeline parallelism we aim to parallelize the evaluation of layers at\n",
+ "different depths in our network. For example, one device might compute the\n",
+ "application of the first layer while another device computes the application of\n",
+ "the second; when they finish, the first device passes its results to the second\n",
+ "while the second passes its results to the device responsible for the third\n",
+ "layer, and the process repeats. In general the number of pipeline stages may be\n",
+ "different from the number of layers, as each stage may be responsible for\n",
+ "multiple layers.\n",
+ "\n",
+ "With SPMD pipelining, we exploit the fact that most layers in the network apply\n",
+ "the same computation, just with different parameter values. In particular, we can\n",
+ "stack together all the parameters except for those for the first and last\n",
+ "layers, then use a `shard_map` to map over blocks of those layer parameters,\n",
+ "where each block of parameters corresponds to a pipeline stage. We then use the\n",
+ "`jax.lax.ppermute` collective to shift data down the parallel pipeline.\n",
+ "\n",
+ "This particular pipelining strategy is essentially [the GPipe\n",
+ "strategy](https://arxiv.org/abs/1811.06965). There are several variants, as\n",
+ "well as quite different strategies, and which is appropriate can depend on the\n",
+ "speed of the networking between stages and batch sizes. But for this tutorial\n",
+ "we'll focus on just one strategy.\n",
+ "\n",
+ "First, we choose some pipeline parameters:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec88fcdb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "L = len(params) - 2 # num layers, excluding first and last\n",
+ "N = batch_size # batch size\n",
+ "F = params[0][0].shape[1] # num features\n",
+ "\n",
+ "# choose some pipeline parameters\n",
+ "S = 2 # number of stages\n",
+ "B = 8 # size of each microbatch\n",
+ "assert L % S == 0, \"S (number of stages) must divide L (number of inner layers)\"\n",
+ "\n",
+ "# compute some useful quantities\n",
+ "M, ragged = divmod(N, B) # M is number of microbatches\n",
+ "assert not ragged, \"B (size of each microbatch) must divide total batch size\"\n",
+ "K, ragged = divmod(M, S) # K is microbatches per stage\n",
+ "assert not ragged, \"S (number of stages) must divide number of microbatches\"\n",
+ "print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')\n",
+ "print(f'{B} examples per microbatch, {M} microbatches total')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e35bd0d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh = Mesh(jax.devices()[:S], ('stages',))\n",
+ "\n",
+ "def predict_pp(params, inputs):\n",
+ " (W_first, b_first), inner_params, (W_last, b_last) = params\n",
+ " inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)\n",
+ " inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),\n",
+ " inner_params, inputs)\n",
+ " outputs = jnp.dot(inputs, W_last) + b_last\n",
+ " return outputs\n",
+ "\n",
+ "@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n",
+ " out_specs=P())\n",
+ "def loss_pp(params, batch):\n",
+ " inputs, targets = batch\n",
+ " predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)\n",
+ " local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))\n",
+ " return jax.lax.pmean(local_loss, 'stages')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "257b88e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def spmd_pipeline(fn, stage_params, inputs):\n",
+ " stage = jax.lax.axis_index('stages')\n",
+ " outputs = jnp.zeros_like(inputs) * jnp.nan\n",
+ " state = jnp.zeros((L // S, B, F)) * jnp.nan\n",
+ " for i in range(M+L-1):\n",
+ " state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))\n",
+ " state = jax.vmap(fn)(stage_params, state)\n",
+ " outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))\n",
+ " state, inputs, outputs = shift(i, state, inputs, outputs)\n",
+ " outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])\n",
+ " return outputs\n",
+ "\n",
+ "def shift(i, state, inputs, outputs):\n",
+ " sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])\n",
+ " state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))\n",
+ " if (i % K) == (-1 % K):\n",
+ " inputs = sh(inputs, +1)\n",
+ " if ((i-L+1) % K) == (-1 % K):\n",
+ " outputs = sh(outputs, +1)\n",
+ " return state, inputs, outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12a478a3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "first_params, *inner_params, last_params = params\n",
+ "Ws, bs = zip(*inner_params)\n",
+ "params_stacked = jnp.stack(Ws), jnp.stack(bs)\n",
+ "first_params = jax.device_put(first_params, NamedSharding(mesh, P()))\n",
+ "params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))\n",
+ "last_params = jax.device_put(last_params, NamedSharding(mesh, P()))\n",
+ "params_ = first_params, params_stacked, last_params\n",
+ "\n",
+ "batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a3086fb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(jax.jit(loss)(params, batch))\n",
+ "print(jax.jit(loss_pp)(params_, batch_))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9308c874",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "formats": "ipynb,md:myst",
+ "main_language": "python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
new file mode 100644
index 000000000..9a8ac1bf9
--- /dev/null
+++ b/docs/notebooks/shard_map.md
@@ -0,0 +1,1373 @@
+---
+jupytext:
+ cell_metadata_filter: -all
+ formats: ipynb,md:myst
+ main_language: python
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.0
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+# Intro
+
+`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.
+
+`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.
+
+If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))
+
+By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.
+
+## So, let's see a `shard_map`!
+
+Without further ado, here's a toy example:
+
+```{code-cell}
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
+```
+
+```{code-cell}
+devices = mesh_utils.create_device_mesh((4, 2))
+mesh = Mesh(devices, axis_names=('x', 'y'))
+
+a = jnp.arange( 8 * 16.).reshape(8, 16)
+b = jnp.arange(16 * 4.).reshape(16, 4)
+
+@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
+ out_specs=P('x', None))
+def matmul_basic(a_block, b_block):
+ # a_block: f32[2, 8]
+ # b_block: f32[8, 4]
+ c_partialsum = jnp.dot(a_block, b_block)
+ c_block = jax.lax.psum(c_partialsum, 'y')
+ # c_block: f32[2, 4]
+ return c_block
+
+c = matmul_basic(a, b) # c: f32[8, 4]
+```
+
+This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:
+
+```{code-cell}
+from jax.tree_util import tree_map, tree_all
+
+def allclose(a, b):
+ return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
+
+allclose(c, jnp.dot(a, b))
+```
+
+The result is sharded along its rows:
+
+```{code-cell}
+jax.debug.visualize_array_sharding(c)
+```
+
+At a high level, `shard_map` is kind of like `vmap` or `pmap`, in that we're
+mapping a function over pieces of array data, but notice that
+* `shard_map` slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas `vmap` would reduce the rank by mapping away an axis;
+* the `mesh` argument lets us control precise device placement of computation and results;
+* we're mapping over multiple data axes at once, and setting up multiple axis names for collectives (both `'x'` and `'y'` here);
+* since we're not using `jax.jit` yet, everything is eagerly evaluated, and we can even `print` intermediate values for debugging.
+
+The above code is performing the same computation as this `jax.jit` automatic parallelization code:
+
+```{code-cell}
+from jax.sharding import NamedSharding
+
+a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
+b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
+
+@jax.jit
+def matmul_reference(a, b):
+ c = jnp.dot(a, b)
+ return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
+
+c_ref = matmul_reference(a, b)
+allclose(c_ref, jnp.dot(a, b))
+```
+
+We can think of `shard_map` as performing a `device_put` or
+`with_sharding_constraint` on its inputs according to its `mesh` and `in_specs`
+arguments, so the blocks over which `matmul_basic` operates are the same as in
+`matmul_reference`:
+
+```{code-cell}
+print('a blocks:'); jax.debug.visualize_array_sharding(a)
+print('b blocks:'); jax.debug.visualize_array_sharding(b)
+print('c blocks:'); jax.debug.visualize_array_sharding(c)
+```
+
+## Slow down, start with the basics!
+
+### Rank-reducing vs rank-preserving maps
+
+We can think of `vmap` and `pmap` as unstacking each array input along an axis
+(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to
+each piece, and stacking the results back together, at least when collectives
+aren't involved:
+
+```{code-cell}
+def check_vmap(f, xs):
+ ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
+ expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics
+ print(allclose(ans, expected))
+
+check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
+```
+
+For example, if `xs` had shape `f32[8,5]` then each `x` would have shape
+`f32[5]`, and if each `f(x)` had shape `f32[3,7]` then the final stacked result
+`vmap(f)(xs)` would have shape `f32[8,3,7]`. That is, each application of the
+body function `f` takes as argument inputs with one fewer axis than the
+corresponding argument to `vmap(f)`. We can say these are _rank-reducing maps_
+with unstacking/stacking of inputs/outputs.
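+
+As a quick check of those hypothetical shapes (the body function below is just
+a placeholder chosen to produce `f32[3,7]` outputs):
+
+```python
+xs = jnp.zeros((8, 5))
+f = lambda x: jnp.zeros((3, 7)) * x[0]   # maps f32[5] -> f32[3,7]
+print(jax.vmap(f)(xs).shape)             # (8, 3, 7)
+```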
+
+The number of logical applications of `f`, or _instances_ of `f`, is determined
+by the size of the input axis being mapped over: for example, if we map over an
+input axis of size 8, semantically we get 8 logical applications of the
+function.
+
+In contrast, `shard_map` does not have this rank-reducing behavior. Instead, we
+can think of it as slicing (or "unconcatenating") along input axes into blocks,
+applying the body function, and concatenating the results back together (again
+when collectives aren't involved):
+
+```{code-cell}
+import numpy as np
+devices = np.array(jax.devices()[:4])
+mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
+
+def check_shmap(f, y):
+ ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
+ expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
+ print(allclose(ans, expected))
+
+check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
+```
+
+Recall that `jnp.split` slices its input into equally-sized blocks with the same
+rank, so that if in the above example `y` had shape `f32[8,5]` then each
+`y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape
+`f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have
+shape `f32[12,7]`. So `shard_map` maps over _shards_, or blocks, of its inputs.
+We can say it's a _rank-preserving map_ with unconcatenating/concatenating of
+its inputs/outputs.
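+
+We can run the analogous shape check for `shard_map` (again with a placeholder
+body function, and assuming the 4-device mesh defined just above):
+
+```python
+y = jnp.zeros((8, 5))
+g = lambda y_blk: jnp.zeros((3, 7)) * y_blk[0, 0]   # maps f32[2,5] -> f32[3,7]
+print(shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'))(y).shape)  # (12, 7)
+```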
+
+The number of logical applications of `f` is determined by the mesh size, not
+by any input axis size: for example, if we have a mesh of total size 4 (i.e.
+over 4 devices) then semantically we get 4 logical applications of the
+function, corresponding to the 4 devices physically computing them.
+
+### Controlling how each input is split (unconcatenated) and tiled with `in_specs`
+
+Each of the `in_specs` identifies some of the corresponding input array's axes
+with mesh axes by name using `PartitionSpec`s, representing how to split (or
+unconcatenate) that input into the blocks to which the body function is
+applied. That identification determines the shard sizes; when an input axis is
+identified with a mesh axis, the input is split (unconcatenated) along that
+logical axis into a number of pieces equal to the corresponding mesh axis size.
+(It's an error if the corresponding mesh axis size does not evenly divide the
+input array axis size.) If an input's pspec does not mention a mesh axis name,
+then there's no splitting over that mesh axis. For example:
+
+```{code-cell}
+devices = mesh_utils.create_device_mesh((4, 2))
+mesh = Mesh(devices, ('i', 'j'))
+
+@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
+def f1(x_block):
+ print(x_block.shape) # prints (3, 12)
+ return x_block
+
+x1 = jnp.arange(12 * 12).reshape(12, 12)
+y = f1(x1)
+```
+
+Here, because the input pspec did not mention the mesh axis name `'j'`, no
+input array axis is split over that mesh axis; similarly, because the second
+axis of the input array is not identified with (and hence split over) any mesh
+axis, application of `f1` gets a full view of the input along that axis.
+
+When a mesh axis is not mentioned in an input pspec, we can always rewrite to a
+less efficient program where all mesh axes are mentioned but the caller
+performs a `jnp.tile`, for example:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
+def f2(x_block):
+ print(x_block.shape)
+ return x_block
+
+x = jnp.arange(12 * 12).reshape(12, 12)
+x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24)
+y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
+```
+
+In other words, because each input pspec can mention each mesh axis name zero
+or one times, rather than having to mention each name exactly once, we can say
+that in addition to the `jnp.split` built into its input, `shard_map` also has
+a `jnp.tile` built into its input, at least logically (though the tiling may
+not need to be carried out physically, depending on the arguments' physical
+sharding layout). The tiling to use is not unique; we could also have tiled
+along the first axis, and used the pspec `P(('j', 'i'), None)`.
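+
+Here's a sketch of that alternative: tiling along the first axis and
+identifying it with the compound `('j', 'i')` mesh axes gives each function
+application the same blocks as `f1` and `f2` above:
+
+```python
+x_alt = jnp.tile(x, (mesh.shape['j'], 1))   # x_alt has shape (24, 12)
+
+@partial(shard_map, mesh=mesh, in_specs=P(('j', 'i'), None),
+         out_specs=P(('j', 'i'), None))
+def f2_alt(x_block):
+  print(x_block.shape)  # prints (3, 12)
+  return x_block
+
+y = f2_alt(x_alt)
+```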
+
+Physical data movement is possible on inputs, as each device needs to have a
+copy of the appropriate data.
+
+### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`
+
+Analogously to the input side, each of the `out_specs` identifies some of the
+corresponding output array's axes with mesh axes by name, representing how the
+output blocks (one for each application of the body function, or equivalently
+one for each physical device) should be assembled back together to form the
+final output value. For example, in both the `f1` and `f2` examples above the
+`out_specs` indicate we should form the final output by concatenating together
+the block results along both axes, resulting in both cases an array `y` of
+shape `(12, 24)`. (It's an error if an output shape of the body function, i.e.
+an output block shape, has a rank too small for the concatenation described by
+the corresponding output pspec.)
+
+When a mesh axis name is not mentioned in an output pspec, it represents an
+un-tiling: when the user writes an output pspec which does not mention one of
+the mesh axis names, they promise that the output blocks are equal along that
+mesh axis, and so only one block along that axis is used in the output (rather
+than concatenating all the blocks together along that mesh axis). For example,
+using the same mesh as above:
+
+```{code-cell}
+x = jnp.array([[3.]])
+
+z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
+print(z) # prints the same as jnp.tile(x, (4, 2))
+
+z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
+print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
+
+z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
+print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
+```
+
+The body function closing over an array value is equivalent to passing it as an
+argument with a corresponding input pspec of `P(None, None)`. As another
+example, following the other examples above more closely:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
+def f3(x_block):
+ return jax.lax.psum(x_block, 'j')
+
+x = jnp.arange(12 * 12).reshape(12, 12)
+y3 = f3(x)
+print(y3.shape)
+```
+
+The result has a second axis size of 6, half the size of the input's second
+axis. In this case, the un-tile expressed by not mentioning the mesh axis name
+`'j'` in the output pspec was safe because of the collective `psum`, which
+ensures each output block is equal along the corresponding mesh axis. Here are
+two more examples where we vary which mesh axes are mentioned in the output
+pspec:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
+def f4(x_block):
+ return jax.lax.psum(x_block, 'i')
+
+x = jnp.arange(12 * 12).reshape(12, 12)
+y4 = f4(x)
+print(y4.shape) # (3,12)
+
+
+@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
+def f5(x_block):
+ return jax.lax.psum(x_block, ('i', 'j'))
+
+y5 = f5(x)
+print(y5.shape) # (3,6)
+```
+
+On the physical side, not mentioning a mesh axis name in an output pspec
+assembles an `Array` from the output device buffers with replicated layout
+along that mesh axis.
+
+There is no runtime check that the output blocks are actually equal along a
+mesh axis to be un-tiled along, or equivalently that the corresponding physical
+buffers have equal values and thus can be interpreted as a replicated layout
+for a single logical array. But we can provide a static check mechanism which
+raises an error on all potentially-incorrect programs.
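+
+For example, here's a sketch of that static check in action: with the default
+`check_rep=True`, an `out_specs` that promises replication over `'j'` when the
+body performs no reduction over `'j'` gets rejected (we expect some kind of
+replication error here, so we catch it just to show it fires):
+
+```python
+try:
+  shard_map(lambda x_block: x_block, mesh=mesh,
+            in_specs=P('i', 'j'), out_specs=P('i', None))(x)
+except Exception as e:
+  print(type(e).__name__)
+```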
+
+Because the `out_specs` can mention mesh axis names zero or one times, and
+because they can be mentioned in any order, we can say that in addition to the
+`jnp.concatenate` built into its output, `shard_map` also has both an _untile_
+and a _block transpose_ built into its output.
+
+Physical data movement is not possible on outputs, no matter the output pspec.
+Instead, `out_specs` just encodes how to assemble the block outputs into
+`Array`s, or physically how to interpret the buffers across devices as the
+physical layout of a single logical `Array`.
+
+# API Specification
+
+```python
+from jax.sharding import Mesh
+Specs = PyTree[PartitionSpec]
+
+def shard_map(
+ f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
+ auto: collections.abc.Set[AxisName] = frozenset([]),
+ check_rep: bool = True,
+) -> Callable:
+ ...
+```
+where:
+* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;
+* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;
+* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
+* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;
+* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).
+
+The shapes of the arguments passed to `f` have the same ranks as the arguments
+passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed
+from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and
+the corresponding `PartitionSpec` `spec` as roughly
+`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.
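+
+For example, here's a small helper (a hypothetical sketch that assumes each
+spec entry is a single mesh axis name or `None`, not a tuple of names)
+spelling out that shape rule for the 4x2 mesh above:
+
+```python
+def block_shape(mesh, shape, spec):
+  return tuple(sz // (1 if n is None else mesh.shape[n])
+               for sz, n in zip(shape, spec))
+
+print(block_shape(mesh, (12, 12), P('i', None)))  # (3, 12)
+print(block_shape(mesh, (12, 12), P('i', 'j')))   # (3, 6)
+```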
+
+# Collectives tutorial
+
+A `shard_map` need not be a pure map: function applications can communicate
+with each other via _collectives_, using axis names defined in the `mesh`
+argument.
+
+Recall that `shard_map` maps a function over shards, or blocks, of input data,
+so that this:
+
+```python
+mesh = Mesh(jax.devices(), ('i',))
+x = jnp.arange(16.)
+f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
+y = f_shmapped(x)
+```
+
+computes the same values, evaluating applications of `f` to the same argument
+values, as this reference function:
+
+```python
+def f_shmapped_ref(x):
+ x_blocks = jnp.array_split(x, mesh.shape[0])
+ y_blocks = [f(x_blk) for x_blk in x_blocks]
+ return jnp.concatenate(y_blocks)
+```
+
+We call these applications of `f` to different argument shards _function
+instances_. Each function instance is executed on a different device (or subset
+of devices).
+
+These reference semantics work when `f` has no communication collectives in
+it. But what if we want the function instances to communicate, corresponding
+to having cross-device communication? That is, what are the reference
+semantics when `f` contains a collective? Say `f` has just one collective, and
+is of the form
+
+```python
+def f(x_blk):
+ z_blk = f_part1(x_blk)
+ u_blk = collective(z_blk, axis_name)
+ v_blk = f_part2(x_blk, z_blk, u_blk)
+ return v_blk
+```
+
+where we're assuming there's only one mesh axis we're mapping over, and
+`axis_name` is the corresponding name for it. Then the reference semantics
+would look more like:
+
+```python
+def f_shmapped_ref(x):
+ x_blocks = jnp.array_split(x, mesh.shape[0])
+ z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
+ u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
+ v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
+ in zip(x_blocks, z_blocks, u_blocks)]
+ return jnp.concatenate(v_blocks)
+```
+
+Notice that `collective_ref` might depend on all the `z_blocks`. That is,
+while `f_part1` and `f_part2` are mapped over blocks independently, a
+collective introduces some amount of cross-block dependence. Physically, that
+means communication across devices. Exactly what communication happens, and
+what values are computed, depend on the collective.
+
+## `psum`
+
+The simplest collective may be `jax.lax.psum`, which computes an
+all-reduce-sum along a device mesh axis (or multiple axes).
+Here's a toy example:
+
+
+
+```{code-cell}
+import jax
+import jax.numpy as jnp
+from jax import lax
+
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+```
+
+```{code-cell}
+mesh1d = Mesh(jax.devices()[:4], ('i',))
+
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
+def f1(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.psum(x_block, 'i')
+ print('AFTER:\n', y_block)
+ return y_block
+```
+
+```{code-cell}
+x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
+y = f1(x)
+print('FINAL RESULT:\n', y)
+```
+
+The prints show that each function application starts with its own chunk of
+the argument value `x_block`. After the `psum`, each function application has
+the same value of `y_block`, computed by summing the applications' `x_block`
+values together.
+
+In the case where there's a single axis name in the computation, we could say
+that the `collective_ref` reference implementation for `psum` is
+
+```python
+def psum_ref(_, x_blocks):
+ tot = sum(x_blocks)
+ return [tot] * len(x_blocks)
+```
+
+Notice also that because `f1` returns `y_block`, the result of a `psum` over
+`'i'`, we can use `out_specs=P()` so the caller gets a single logical copy of
+the result value, rather than a tiled result.
+
+When there is more than one mesh axis, we can perform a `psum` over
+each one separately, or over multiple axes at once:
+
+```{code-cell}
+mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
+
+@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
+def f2(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.psum(x_block, 'i')
+ print('AFTER:\n', y_block)
+ return y_block
+
+y = f2(jnp.arange(16).reshape(4, 4))
+print('FINAL RESULT:\n', y)
+```
+
+By applying a `psum` over mesh axis `'i'`, we get values of `y_block` which
+are equal along axis `'i'`, but not axis `'j'`. (So we can use
+`out_specs=P(None, 'j')` to get a single logical result along that axis.)
+
+If we apply the `psum` over both axes, the `y_block` value is equal along both
+axes:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
+def f3(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.psum(x_block, ('i', 'j'))
+ print('AFTER:\n', y_block)
+ return y_block
+
+y = f3(jnp.arange(16).reshape(4, 4))
+print('FINAL RESULT:\n', y)
+```
+
+In machine learning, we often use `psum` to compute total losses or, when we
+have a `grad` inside the `shard_map`ped function body, total gradients.
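+
+For instance, here's a sketch of that grad-inside-`shard_map` pattern, with a
+made-up scalar parameter `w` and a toy local loss: each function instance
+computes the gradient of its local loss, and a `psum` produces the total
+gradient.
+
+```python
+def local_loss(w, x_block):
+  return jnp.sum((w * x_block) ** 2)   # toy per-shard loss, for illustration
+
+@partial(shard_map, mesh=mesh1d, in_specs=(P(), P('i')), out_specs=P())
+def total_grad(w, x_block):
+  local_grad = jax.grad(local_loss)(w, x_block)
+  return jax.lax.psum(local_grad, 'i')
+
+print(total_grad(jnp.array(3.0), jnp.arange(16.)))
+```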
+
+In the sequel, we'll see how `psum` can be implemented in terms of other
+primitives, which gives some intuition about its communication cost.
+
+## `all_gather`
+
+Another fundamental operation is gathering array shards along an axis, so that
+each function application has a full copy of the data along that axis:
+
+
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f4(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
+ print('AFTER:\n', y_block)
+ return y_block
+
+x = jnp.array([3, 9, 5, 2])
+y = f4(x)
+print('FINAL RESULT:\n', y)
+```
+
+The prints show that each function application again starts with its own chunk
+of the argument value `x_block`. After the `all_gather`, they have a common
+value, computed by concatenating the values of `x_block`.
+
+(Notice that we actually can't set `out_specs=P()` here. For technical
+reasons related to automatic differentiation, we consider the output of
+`all_gather` not to be guaranteed invariant across devices. If we wanted it to
+be guaranteed invariant, we could use `jax.lax.all_gather_invariant`, or in
+this case we could just avoid doing the `all_gather` in the function body and
+instead just use `out_specs=P('i')` to perform the concatenation.)
+
+When `tiled=False` (the default), results are stacked along a new axis instead
+of concatenated:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f5(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
+ print('AFTER:\n', y_block)
+ return y_block
+
+y = f5(x)
+print('FINAL RESULT:\n', y)
+```
+
+We could write the `collective_ref` reference semantics function for
+`all_gather` as
+
+```python
+def all_gather_ref(_, x_blocks, *, tiled=False):
+ combine = jnp.concatenate if tiled else jnp.stack
+ return [combine(x_blocks)] * len(x_blocks)
+```
+
+In deep learning, we might use `all_gather`s on parameters in fully sharded
+data parallelism (FSDP).
+
+## `psum_scatter`
+
+The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like
+`psum` except each function instance gets only one shard of the result:
+
+
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f6(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
+ print('AFTER:\n', y_block)
+ return y_block
+
+x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
+y = f6(x)
+print('FINAL RESULT:\n', y)
+```
+
+As shown by the prints, each resulting `y_block` has a smaller size than the
+argument `x_block`, unlike with `psum`. Moreover, compared to `psum`, here
+each `y_block` only represents a slice of the sum of the `x_block`s across
+function instances. (Even though each function instance gets only one shard of
+the sum, the final output `y` is the same as in the `psum` example because
+here we use `out_specs=P('i')` to concatenate each function instance's
+output.)
+
+In terms of what values are computed, a `collective_ref` reference
+implementation might look like:
+
+```python
+def psum_scatter_ref(i, x_blocks, *, tiled=False):
+ axis_size = len(x_blocks)
+ tot = sum(x_blocks)
+ if tiled:
+ tot = tot.reshape(axis_size, -1, *tot.shape[1:]) # split leading axis
+ return [tot[i] for i in range(tot.shape[0])]
+```
+
+It's not captured in the semantics reference implementation, but
+`psum_scatter` is useful because these results can be computed more
+efficiently, with less communication, than a full `psum`. In fact, one way to
+think of `psum_scatter` is as "the first half of a `psum`, before an
+`all_gather`". That is, one way to implement `psum` is:
+
+```python
+def psum(x, axis_name):
+ summed_chunk = jax.lax.psum_scatter(x, axis_name)
+ return jax.lax.all_gather(summed_chunk, axis_name)
+```
+
+Indeed, this implementation is often used on both TPU and GPU!
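+
+As a quick numerical check of that decomposition (a sketch reusing `mesh1d`
+from above, with the `tiled=True` variants of both collectives):
+
+```python
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def psum_decomposition_check(x_block):
+  via_psum = jax.lax.psum(x_block, 'i')
+  via_scatter_then_gather = jax.lax.all_gather(
+      jax.lax.psum_scatter(x_block, 'i', tiled=True), 'i', tiled=True)
+  return via_psum - via_scatter_then_gather  # expect all zeros
+
+print(psum_decomposition_check(jnp.arange(16.)))
+```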
+
+The reason `psum_scatter` can require about half as much communication as a full
+`psum` is illustrated in the `ppermute` section.
+
+Another intuition is that we can use `psum_scatter` to implement a distributed
+matrix multiplication with inputs and outputs sharded over the same axis. In
+machine learning, `psum_scatter` can be used in tensor-parallel matrix
+multiplies or fully-sharded data parallel gradient accumulation, as shown in
+the examples to follow.
+
+## `ppermute`
+
+The `jax.lax.ppermute` collective provides the most direct way for
+function instances to send data to one another. Given a mesh axis and a
+list of `(source_index, destination_index)` pairs representing indices along
+that mesh axis, `ppermute` sends its argument value from each source function
+instance to each destination:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f7(x_block):
+ sz = jax.lax.psum(1, 'i')
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
+ print('AFTER:\n', y_block)
+ return y_block
+
+y = f7(jnp.arange(8))
+print('FINAL RESULT:\n', y)
+```
+
+In this case, with four function instances arranged in a ring, each instance's
+value of `y_block` is the preceding instance's value of `x_block`.
+
+Source indices and destination indices can't be repeated. If an index does not
+appear as a destination, then the value of the corresponding function
+instance's result is an array of zeros.
+
+A `collective_ref` reference implementation could look like
+
+```python
+def ppermute_ref(i, x_blocks, perm):
+ results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
+ for src, dst in perm:
+ results[dst] = x_blocks[src]
+ return results
+```
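+
+For instance, plugging a permutation that names only one destination into the
+reference semantics above (with some made-up blocks) leaves every other
+instance's result as zeros:
+
+```python
+blocks = [jnp.full(2, float(i + 1)) for i in range(4)]  # [1s, 2s, 3s, 4s]
+print(ppermute_ref(None, blocks, perm=[(0, 1)]))
+# instance 1 gets blocks[0]; instances 0, 2, and 3 get zeros
+```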
+
+Other collectives can be implemented efficiently, in terms of total
+communication, using `ppermute`s where each function passes data only to its
+neighbors. For example, we could implement `psum_scatter` using a sequence of
+`ppermute`s and local additions this way:
+
+
+
+Or, with a numerical example:
+
+
+
+
+Intuitively, on each iteration each function instance sends 'up' the value it
+received on the previous iteration, and reduces (adds) the value it receives
+this iteration. In code, it might look like this:
+
+```{code-cell}
+def psum_scatter(x, axis_name, *, tiled=False):
+ size = jax.lax.psum(1, axis_name)
+ idx = jax.lax.axis_index(axis_name) # function instance index along axis_name
+ if tiled:
+ x = x.reshape(size, -1, *x.shape[1:]) # split leading axis
+ shift = partial(jax.lax.ppermute, axis_name=axis_name,
+ perm=[(i, (i - 1) % size) for i in range(size)])
+ for i in range(1, size):
+ update = shift(x[(idx + i) % size])
+ x = x.at[(idx + i + 1) % size].add(update)
+ return x[idx]
+```
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f8(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = psum_scatter(x_block, 'i', tiled=True)
+ print('AFTER:\n', y_block)
+ return y_block
+
+x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
+y = f8(x)
+print('FINAL RESULT:\n', y)
+```
+
+On TPU, there are higher-dimensional variants of this algorithm to exploit
+multiple bidirectional physical mesh axes.
+
+Notice that `psum_scatter` is the transpose of `all_gather`. Indeed, a way to
+implement `all_gather` in terms of `ppermute` looks like the reverse of the
+above process:
+
+
+
+In deep learning, we might use `ppermute` when implementing SPMD pipeline
+parallelism, where we divide our network along its depth into stages and
+evaluate the applications of stages in parallel. Or we might use `ppermute` in
+parallelizing the evaluation of convolutional layers, where we shard over
+spatial axes and thus devices must communicate "halos" to each other. Or it
+may be used under-the-hood in tensor-parallel matrix multiplies.
+
+## `all_to_all`
+
+A final collective is `all_to_all`, which is essentially a block matrix
+transpose operating along one positional axis and one cross-device axis:
+
+
+
+```{code-cell}
+@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
+def f9(x_block):
+ print('BEFORE:\n', x_block)
+ y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
+ tiled=True)
+ print('AFTER:\n', y_block)
+ return y_block
+
+x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
+y = f9(x)
+print('FINAL RESULT:\n', y)
+```
+
+The `split_axis` argument indicates which positional axis should be sharded
+and partitioned across the mesh axis. The `concat_axis` argument indicates the
+axis along which the communicated results should be concatenated or stacked.
+
+When `tiled=False` (the default), the `split_axis` axis size must equal the
+size of the mesh axis named `axis_name`, and a new axis of that size is
+created at position `concat_axis` for the stacked results. When `tiled=True`,
+the `split_axis` axis size need only be evenly divisible by the size of the
+mesh axis, and results are concatenated along the existing axis `concat_axis`.
+
+The `collective_ref` reference semantics when `split_axis=0` and
+`concat_axis=0` might look like:
+
+```python
+def all_to_all_ref(_, x_blocks, *, tiled=False):
+ axis_size = len(x_blocks)
+ if tiled:
+ splits = [jnp.array_split(x, axis_size) for x in x_blocks]
+ return [jnp.concatenate(s) for s in zip(*splits)]
+ else:
+ splits = [list(x) for x in x_blocks]
+ return [jnp.stack(s) for s in zip(*splits)]
+```
+
+In deep learning, we might use `all_to_all` in mixture-of-expert routing,
+where we first sort our local batch of examples according to which expert they
+should go to, then apply an `all_to_all` to redistribute examples to experts.
+
+# Toy examples
+
+How might we use `shard_map` and collective communication in practice? These
+examples, while simple, give some idea.
+
+## Matrix multiplies
+
+Parallelizing matrix multiplication is central in scaling up deep learning
+models, both for training and for inference. When `jax.jit` automatically
+parallelizes matrix multiplication, it can use one of several different
+strategies, depending on matrix sizes, hardware details, and other factors. How
+might we write some of those parallelized routines more explicitly using
+`shard_map`? And how can we optimize them to get better compute/communication
+overlap and thus improve FLOP utilization?
+
+```{code-cell}
+import jax
+import jax.numpy as jnp
+
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+```
+
+```{code-cell}
+mesh = Mesh(jax.devices()[:4], ('i',))
+
+def device_put(x, pspec):
+ return jax.device_put(x, NamedSharding(mesh, pspec))
+```
+
+### Example 1: `all-gather` on one side
+
+Consider performing a matrix multiplication where we shard the left-hand side
+argument (can think: parameters) on its leading (non-contracting) dimension:
+
+```{code-cell}
+lhs_spec = P('i', None)
+lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
+```
+
+And we shard the right-hand side argument (can think: activations) on its
+contracting dimension, with a similar sharding for the output:
+
+```{code-cell}
+rhs_spec = P('i', None)
+rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
+```
+
+To perform this matrix multiplication, we can first all-gather the right-hand
+side and then perform local matrix multiplies against the sharded left-hand
+side:
+
+```{code-cell}
+@jax.jit
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_allgather(lhs_block, rhs_block):
+ rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
+ return lhs_block @ rhs
+```
+
+```{code-cell}
+out = matmul_allgather(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+That's great, but we're not getting any compute/communication overlap
+here: before we can start the matmul, we need the all_gather to complete.
+Here's a profile using the same code, but on larger example shapes (`(8192,
+8192)` for `lhs` and `(8192, 1024)` for `rhs`):
+
+
+
+We can get compute/communication overlap if instead of calling `all_gather` we
+basically inline our above implementation of `all_gather` in terms of
+`ppermute`, then interleave steps of the gather permutation with local matrix
+multiplies:
+
+```{code-cell}
+@jax.jit
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_allgather_overlapped(lhs_block, rhs_block):
+ size = jax.lax.psum(1, 'i')
+ idx = jax.lax.axis_index('i')
+ shift = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i + 1) % size) for i in range(size)])
+
+ B = lhs_block.shape[1] // size
+ lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
+
+ out_block = lhs_blocks(idx) @ rhs_block
+ for i in range(1, size):
+ rhs_block = shift(rhs_block)
+ out_block += lhs_blocks((idx - i) % size) @ rhs_block
+ return out_block
+```
+
+```{code-cell}
+out = matmul_allgather_overlapped(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+This implementation allows overlap between communication and computation, and
+also avoids gathering a large intermediate onto each device. But on TPU it uses
+only half the interconnect bandwidth by permuting in only one direction along
+the ring. To permute bidirectionally, we just split the blocks in half and send
+each half in each direction:
+
+```{code-cell}
+@jax.jit
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
+ size = jax.lax.psum(1, 'i')
+ idx = jax.lax.axis_index('i')
+ shift_up = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i + 1) % size) for i in range(size)])
+ shift_dn = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i - 1) % size) for i in range(size)])
+
+ B = lhs_block.shape[1] // size // 2 # half-size blocks
+ lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
+
+ rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
+ out_block = lhs_blocks(idx, 0) @ rhs_block_lo
+ out_block += lhs_blocks(idx, 1) @ rhs_block_hi
+ for i in range(1, size):
+ rhs_block_lo = shift_up(rhs_block_lo)
+ rhs_block_hi = shift_dn(rhs_block_hi)
+ out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
+ out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
+ return out_block
+```
+
+```{code-cell}
+out = matmul_allgather_overlapped_bidi(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+
+
+In practice, to reduce compile times we would probably roll this into a
+`jax.lax.fori_loop`. We might also have additional axes of parallelism
+involved.
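+
+Here's a sketch of what rolling the unidirectional overlapped version into a
+`jax.lax.fori_loop` might look like (same shardings and helper names as above;
+`matmul_allgather_overlapped_fori` is just a name made up for this sketch):
+
+```python
+@jax.jit
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+         out_specs=rhs_spec)
+def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
+  size = jax.lax.psum(1, 'i')
+  idx = jax.lax.axis_index('i')
+  shift = partial(jax.lax.ppermute, axis_name='i',
+                  perm=[(i, (i + 1) % size) for i in range(size)])
+
+  B = lhs_block.shape[1] // size
+  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
+
+  def body(i, carry):
+    out_block, rhs_blk = carry
+    rhs_blk = shift(rhs_blk)                             # communicate
+    out_block += lhs_blocks((idx - i) % size) @ rhs_blk  # compute
+    return out_block, rhs_blk
+
+  out_block = lhs_blocks(idx) @ rhs_block
+  out_block, _ = jax.lax.fori_loop(1, size, body, (out_block, rhs_block))
+  return out_block
+
+out = matmul_allgather_overlapped_fori(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```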
+
+### Example 2: `psum_scatter` the result
+
+Another sharding we might start with has both `lhs` and `rhs` sharded along
+their contracting dimensions, with the output sharded like `rhs` again:
+
+```{code-cell}
+lhs_spec = P(None, 'i')
+lhs = device_put(lhs, lhs_spec)
+
+rhs_spec = P('i', None)
+rhs = device_put(rhs, rhs_spec)
+```
+
+Here we can use a reduce-scatter, via `jax.lax.psum_scatter`, to perform the contraction sum over shards:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_psumscatter(lhs_block, rhs_block):
+ out_summand = lhs_block @ rhs_block
+ return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
+
+out = matmul_psumscatter(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+But the scattering communication must wait for the entire local matrix multiply
+to finish before it can start. To get communication/computation overlap, we can
+inline an implementation of `psum_scatter` in terms of `ppermute`, then
+interleave the communication steps with local matrix multiplies:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_psumscatter_overlapped(lhs_block, rhs_block):
+ size = jax.lax.psum(1, 'i')
+ idx = jax.lax.axis_index('i')
+ shift = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i - 1) % size) for i in range(size)])
+ lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1]) # split 1st axis
+
+ out_summand = lhs_block[(idx + 1) % size] @ rhs_block
+ for i in range(1, size):
+ out_summand = shift(out_summand)
+ out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
+ return out_summand
+```
+
+```{code-cell}
+out = matmul_psumscatter_overlapped(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+As in the previous example, to fully utilize interconnects on TPU, we'd run a
+bidirectional version:
+
+```{code-cell}
+@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
+ out_specs=rhs_spec)
+def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
+ size = jax.lax.psum(1, 'i')
+ idx = jax.lax.axis_index('i')
+ shift_up = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i + 1) % size) for i in range(size)])
+ shift_dn = partial(jax.lax.ppermute, axis_name='i',
+ perm=[(i, (i - 1) % size) for i in range(size)])
+
+ B = lhs_block.shape[0] // size // 2 # half-size blocks
+ lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
+
+ out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
+ out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
+ for i in range(1, size):
+ out_summand_lo = shift_up(out_summand_lo)
+ out_summand_hi = shift_dn(out_summand_hi)
+ out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
+ out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
+ return jnp.concatenate([out_summand_lo, out_summand_hi])
+```
+
+```{code-cell}
+out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
+print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
+```
+
+## Neural networks
+
+We can use `shard_map` to parallelize computation in neural networks, either by
+itself or in combination with the automatic partitioning in `jax.jit`. This
+section has a few examples based on this toy neural network and random data:
+
+```{code-cell}
+import jax
+import jax.numpy as jnp
+
+def predict(params, inputs):
+ for W, b in params:
+ outputs = jnp.dot(inputs, W) + b
+ inputs = jax.nn.relu(outputs)
+ return outputs
+
+def loss(params, batch):
+ inputs, targets = batch
+ predictions = predict(params, inputs)
+ return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
+```
+
+```{code-cell}
+def init_layer(key, n_in, n_out):
+ k1, k2 = jax.random.split(key)
+ W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
+ b = jax.random.normal(k2, (n_out,))
+ return W, b
+
+def init(key, layer_sizes, batch_size):
+ key, *keys = jax.random.split(key, len(layer_sizes))
+ params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
+
+ key, *keys = jax.random.split(key, 3)
+ inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
+ targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
+
+ return params, (inputs, targets)
+```
+
+```{code-cell}
+layer_sizes = [784, 128, 128, 128, 128, 128, 8]
+batch_size = 32
+
+params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
+```
+
+Compare these examples with the purely [automatic partitioning examples in the
+"Distributed arrays and automatic parallelization"
+doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
+While in those automatic partitioning examples we don't need to edit the model
+functions to use different parallelization strategies, with `shard_map` we
+often do.
+
+### 8-way batch data parallelism
+
+The simplest multi-device parallelism strategy is to shard the batch of inputs
+and targets over multiple devices, replicate the parameters over those devices,
+and apply the model in parallel to those shards of data. To evaluate the total
+loss, the devices need only communicate with a scalar-sized all-reduce-sum at
+the end. (To evaluate the gradient of the loss, the devices must perform
+all-reduce-sums of parameter gradients in the backward pass.)
+
+```{code-cell}
+from functools import partial
+
+from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from jax.experimental import mesh_utils
+
+devices = mesh_utils.create_device_mesh((8,))
+
+# replicate initial params on all devices, shard data batch over devices
+mesh = Mesh(devices, ('batch',))
+batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
+params = jax.device_put(params, NamedSharding(mesh, P()))
+
+# adapt the loss function to sum the losses across devices
+def loss_dp(params, batch):
+ @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),
+ check_rep=False) # TODO remove check_rep=False
+ def loss_spmd(local_batch):
+ inputs, targets = local_batch
+ predictions = predict(params, inputs) # use reference 'predict`
+ local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
+ return jax.lax.pmean(local_loss, 'batch')
+ return loss_spmd(batch)
+```
+
+We can check that the loss and its gradients match the reference (base) model:
+
+```{code-cell}
+print(jax.jit(loss)(params, batch))
+print(jax.jit(loss_dp)(params, batch))
+```
+
+```{code-cell}
+def allclose(a, b):
+ return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
+
+print(allclose(jax.jit(jax.grad(loss))(params, batch),
+ jax.jit(jax.grad(loss_dp))(params, batch)))
+```
+
+We can print the compiler IR to inspect the gradient computation and verify
+that the collective all-reduce-sum operations happen where we'd expect: at the
+end of the forward pass to compute the loss value, and in the backward pass to
+compute the total parameter gradients.
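+
+For example, one way to do that (a sketch using `jax.jit`'s ahead-of-time
+`lower`/`compile` APIs; the exact textual form of the dump is
+compiler-dependent) is to grep the optimized HLO for all-reduce operations:
+
+```{code-cell}
+hlo = jax.jit(jax.grad(loss_dp)).lower(params, batch).compile().as_text()
+print('\n'.join(line for line in hlo.splitlines() if 'all-reduce' in line))
+```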
+
+### 8-way fully sharded data parallelism (FSDP)
+
+Another strategy is to additionally shard the parameters over the devices,
+all-gathering each one when the full value is needed for the `jnp.dot` or bias
+addition. Since we only have one full parameter in local device memory at a
+time, rather than keeping all parameters in all device memories as in the
+preceding DP example, we free up significant memory that we can use for larger
+models or larger batch sizes. And because XLA will overlap computation and
+inter-device communication, the wall-clock time doesn't suffer.
+
+So now we need collectives in two places: the model prediction function
+`predict` needs to all-gather the parameters before they're used, and as in the
+DP case the loss function needs to sum the local losses to compute the total
+loss.
+
+There's one other ingredient we need: we don't want to store the fully gathered
+parameters from the forward pass for use on the backward pass. Instead, we want
+to gather them again on the backward pass. We can express that by using
+`jax.remat` with a [custom
+policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)
+(or a `custom_vjp`), though XLA typically does that rematerialization
+automatically.
+
+This general [FSDP
+approach](https://engineering.fb.com/2021/07/15/open-source/fsdp/) is similar
+to [weight update sharding (WUS)](https://arxiv.org/abs/2004.13336) and
+[ZeRO-3](https://arxiv.org/abs/1910.02054).
+
+```{code-cell}
+# shard data batch *and params* over devices
+mesh = Mesh(devices, ('batch',))
+batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
+params = jax.device_put(params, NamedSharding(mesh, P('batch')))
+
+# adapt the prediction function to gather weights just before their use,
+# and to re-gather them on the backward pass (rather than saving them)
+@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
+def predict_fsdp(params_frag, inputs):
+ for W_frag, b_frag in params_frag:
+ W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
+ b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
+ outputs = jnp.dot(inputs, W) + b
+ inputs = jax.nn.relu(outputs)
+ return outputs
+
+def loss_fsdp(params, batch):
+ @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
+ def loss_spmd(local_params, local_batch):
+ inputs, targets = local_batch
+ predictions = predict_fsdp(local_params, inputs)
+ local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
+ return jax.lax.pmean(local_loss, 'batch')
+ return loss_spmd(params, batch)
+```
+
+Again we can check that the loss and its gradients match the reference model:
+
+```{code-cell}
+print(jax.jit(loss)(params, batch))
+print(jax.jit(loss_fsdp)(params, batch))
+
+print(allclose(jax.jit(jax.grad(loss))(params, batch),
+ jax.jit(jax.grad(loss_fsdp))(params, batch)))
+```
+
+### 8-way tensor parallelism (TP)
+
+Usually we don't use tensor model parallelism by itself, but seeing it in
+isolation is a good warmup on parallel matrix multiplication. It's also a good
+example of using `shard_map` in a library function, called in a larger
+`jit`-based computation.
+
+The parallelization idea is that we'll keep the data/activations sharded over
+their feature axis (rather than their batch axis), and we'll similarly shard
+weight matrices over their input-feature axis (and biases over their feature
+axis).
+Then to perform the parallel matrix multiplication, we'll perform local matrix
+multiplications followed by a `psum_scatter` to sum the local results and
+efficiently scatter the result's shards.
+
+```{code-cell}
+devices = mesh_utils.create_device_mesh((8,))
+mesh = Mesh(devices, ('feats',))
+
+batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
+params = jax.device_put(params, NamedSharding(mesh, P('feats')))
+
+def predict_tp(params, inputs):
+ for W, b in params:
+ outputs = gemm_tp(inputs, W, b)
+ inputs = jax.nn.relu(outputs)
+ return outputs
+
+@partial(shard_map, mesh=mesh,
+ in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
+ out_specs=P(None, 'feats'))
+def gemm_tp(inputs, W, b):
+ block_result = jnp.dot(inputs, W)
+ return jax.lax.psum_scatter(block_result, 'feats',
+ scatter_dimension=1, tiled=True) + b
+
+def loss_tp(params, batch):
+ inputs, targets = batch
+ predictions = predict_tp(params, inputs)
+ return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
+```
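+
+As before, we can sanity-check the loss and its gradients against the
+reference model (a check cell in the same spirit as the earlier sections):
+
+```{code-cell}
+print(jax.jit(loss)(params, batch))
+print(jax.jit(loss_tp)(params, batch))
+
+print(allclose(jax.jit(jax.grad(loss))(params, batch),
+               jax.jit(jax.grad(loss_tp))(params, batch)))
+```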
+
+### FSDP + TP, with `shard_map` at the top level
+
+We can compose these strategies together, using multiple axes of parallelism.
+
+```{code-cell}
+devices = mesh_utils.create_device_mesh((4, 2))
+mesh = Mesh(devices, ('batch', 'feats'))
+
+batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
+params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
+
+# mostly same as previous predict_fsdp definition, except we call gemm_tp
+@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
+def predict_fsdp_tp(params_frag, inputs):
+ for W_frag, b_frag in params_frag:
+ W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
+ b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
+ block_result = jnp.dot(inputs, W)
+ outputs = jax.lax.psum_scatter(block_result, 'feats',
+ scatter_dimension=1, tiled=True) + b
+ inputs = jax.nn.relu(outputs)
+ return outputs
+
+@partial(shard_map, mesh=mesh,
+         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
+         out_specs=P())
+def loss_fsdp_tp(local_params, local_batch):
+ inputs, targets = local_batch
+ predictions = predict_fsdp_tp(local_params, inputs)
+ sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
+ return jax.lax.pmean(jnp.mean(sq_err), 'batch')
+```
+
+Notice how we have to do _two_ collective reductions: one over `'feats'` and
+one over `'batch'`. In the pure TP example, we didn't write the `'feats'`
+reduction explicitly because we only used `shard_map` within `gemm_tp`; in the
+caller `loss_tp`, the compiler automatically translated our use of `jnp.sum` to
+perform a `psum` as needed given the sharded result returned by `predict_tp`.
+
+```{code-cell}
+print(jax.jit(loss)(params, batch))
+print(jax.jit(loss_fsdp_tp)(params_, batch_))
+
+print(allclose(jax.jit(jax.grad(loss))(params, batch),
+               jax.jit(jax.grad(loss_fsdp_tp))(params_, batch_)))
+```
+
+### SPMD pipeline parallelism (PP)
+
+With pipeline parallelism we aim to parallelize the evaluation of layers at
+different depths in our network. For example, one device might compute the
+application of the first layer while another device computes the application of
+the second; when they finish, the first device passes its results to the second
+while the second passes its results to the device responsible for the third
+layer, and the process repeats. In general the number of pipeline stages may be
+different from the number of layers, as each stage may be responsible for
+multiple layers.
+
+With SPMD pipelining, we exploit the fact that most layers in the network apply
+the same computation, just with different parameter values. In particular, we
+can
+stack together all the parameters except for those for the first and last
+layers, then use a `shard_map` to map over blocks of those layer parameters,
+where each block of parameters corresponds to a pipeline stage. We then use the
+`jax.lax.ppermute` collective to shift data down the parallel pipeline.
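+
+As a tiny standalone warm-up on that shifting primitive (the two-device
+`'stages'` mesh in this cell is a throwaway, separate from the one used
+below), `ppermute` sends each instance's value to its neighbor around the
+ring:
+
+```{code-cell}
+tmp_mesh = Mesh(jax.devices()[:2], ('stages',))
+
+@partial(shard_map, mesh=tmp_mesh, in_specs=P('stages'), out_specs=P('stages'))
+def rotate(x):
+  # send each shard one position around the ring of stages
+  return jax.lax.ppermute(x, 'stages', [(i, (i + 1) % 2) for i in range(2)])
+
+print(rotate(jnp.arange(2.)))  # [1. 0.]
+```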
+
+This particular pipelining strategy is essentially [the GPipe
+strategy](https://arxiv.org/abs/1811.06965). There are several variants, as
+well as quite different strategies, and which one is appropriate can depend on
+the speed of the networking between stages and on the batch size. But for this
+tutorial we'll focus on just one strategy.
+
+First, we choose some pipeline parameters:
+
+```{code-cell}
+L = len(params) - 2 # num layers, excluding first and last
+N = batch_size # batch size
+F = params[0][0].shape[1] # num features
+
+# choose some pipeline parameters
+S = 2 # number of stages
+B = 8 # size of each microbatch
+assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
+
+# compute some useful quantities
+M, ragged = divmod(N, B) # M is number of microbatches
+assert not ragged, "B (size of each microbatch) must divide total batch size"
+K, ragged = divmod(M, S) # K is microbatches per stage
+assert not ragged, "S (number of stages) must divide number of microbatches"
+print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
+print(f'{B} examples per microbatch, {M} microbatches total')
+```
+
+```{code-cell}
+mesh = Mesh(jax.devices()[:S], ('stages',))
+
+def predict_pp(params, inputs):
+ (W_first, b_first), inner_params, (W_last, b_last) = params
+ inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
+ inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
+ inner_params, inputs)
+ outputs = jnp.dot(inputs, W_last) + b_last
+ return outputs
+
+@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
+ out_specs=P())
+def loss_pp(params, batch):
+ inputs, targets = batch
+ predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
+ local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
+ return jax.lax.pmean(local_loss, 'stages')
+```
+
+```{code-cell}
+def spmd_pipeline(fn, stage_params, inputs):
+  stage = jax.lax.axis_index('stages')
+  outputs = jnp.zeros_like(inputs) * jnp.nan
+  state = jnp.zeros((L // S, B, F)) * jnp.nan  # one microbatch slot per local layer
+  for i in range(M+L-1):
+    # feed the next microbatch into the first layer slot (only on stage 0)
+    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
+    # each local layer processes the microbatch currently in its slot
+    state = jax.vmap(fn)(stage_params, state)
+    # the last stage records any finished microbatch
+    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
+    # pass activations (and, periodically, the input/output buffers) along the ring
+    state, inputs, outputs = shift(i, state, inputs, outputs)
+  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
+  return outputs
+
+def shift(i, state, inputs, outputs):
+  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
+  # roll layer slots down by one, sending the last layer's output to the next stage
+  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
+  if (i % K) == (-1 % K):  # after every K steps, rotate the input buffers
+    inputs = sh(inputs, +1)
+  if ((i-L+1) % K) == (-1 % K):  # likewise for outputs, offset by the pipeline latency
+    outputs = sh(outputs, +1)
+  return state, inputs, outputs
+```
+
+```{code-cell}
+first_params, *inner_params, last_params = params
+Ws, bs = zip(*inner_params)
+params_stacked = jnp.stack(Ws), jnp.stack(bs)
+first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
+params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
+last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
+params_ = first_params, params_stacked, last_params
+
+batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
+```
+
+```{code-cell}
+print(jax.jit(loss)(params, batch))
+print(jax.jit(loss_pp)(params_, batch_))
+```
+
+```{code-cell}
+_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash
+```