diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst
index 37cab679d..73cea4a40 100644
--- a/docs/jax.experimental.rst
+++ b/docs/jax.experimental.rst
@@ -26,6 +26,7 @@ Experimental Modules
     jax.experimental.compilation_cache
     jax.experimental.key_reuse
     jax.experimental.mesh_utils
+    jax.experimental.shard_map
 
 Experimental APIs
 -----------------
diff --git a/docs/jax.experimental.shard_map.rst b/docs/jax.experimental.shard_map.rst
new file mode 100644
index 000000000..65be7f21b
--- /dev/null
+++ b/docs/jax.experimental.shard_map.rst
@@ -0,0 +1,12 @@
+``jax.experimental.shard_map`` module
+=====================================
+
+.. automodule:: jax.experimental.shard_map
+
+API
+---
+
+.. autosummary::
+   :toctree: _autosummary
+
+   shard_map
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index a88571a90..c975a313c 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -20,7 +20,7 @@ import inspect
 import itertools as it
 from math import prod
 import operator as op
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union
 
 import numpy as np
 
@@ -29,7 +29,6 @@ import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec, Mesh
 from jax._src import ad_checkpoint
 from jax._src import ad_util
-from jax._src import array
 from jax._src import callback
 from jax._src import core
 from jax._src import custom_derivatives
@@ -52,8 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           special, control_flow, ann)
 from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                            as_hashable_function, memoize, partition_list,
-                           merge_lists, split_list, subs_list2,
-                           weakref_lru_cache)
+                           merge_lists, split_list, subs_list2)
 from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -82,6 +80,56 @@ AxisName = Hashable
 @traceback_util.api_boundary
 def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
               check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
+  """Map a function over shards of data.
+
+  Note:
+    ``shard_map`` is an experimental API and still subject to change. For an
+    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+
+  Args:
+    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``f``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``f``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
+      with a tree structure that is a tree prefix of the args tuple to be mapped
+      over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec``
+      represents how the corresponding argument (or subtree of arguments) should
+      be sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
+      mentioning a ``mesh`` axis name at a position expresses sharding the
+      corresponding argument array axis along that positional axis; not
+      mentioning an axis name expresses replication.
+    out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
+      with a tree structure that is a tree prefix of the output of ``f``. Each
+      ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+      a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis. Not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that, rather than concatenating, only a single value should be produced.
+    check_rep: If True (default), enable additional validity checks and automatic
+      differentiation optimizations. The validity checks concern whether any mesh
+      axis names not mentioned in ``out_specs`` are consistent with how the outputs
+      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
+    auto: (experimental) an optional set of axis names from ``mesh`` over which we
+      do not shard the data or map the function, but rather allow the
+      compiler to control sharding. These names cannot be used in ``in_specs``,
+      ``out_specs``, or in communication collectives in ``f``.
+
+  Returns:
+    A callable that applies the input function ``f`` across data sharded according to
+    the ``mesh`` and ``in_specs``.
+
+  Examples:
+    For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_.
+
+    .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
+  """
   return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
 
 def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
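
Not part of the diff above: a minimal usage sketch of the parameters the new docstring documents, in the spirit of the linked tutorial. It assumes 8 available devices arranged into a 4x2 mesh; the shapes and the name matmul_basic are illustrative only.

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Build a 4x2 device mesh with named axes 'x' and 'y' (assumes 8 devices).
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # Each instance receives a (2, 8) block of `a` and an (8, 4) block of `b`.
  c_partial = jnp.dot(a_block, b_block)
  # 'y' is not mentioned in out_specs, so the output must be replicated over
  # 'y'; the psum over 'y' provides that, and no concatenation happens there.
  return jax.lax.psum(c_partial, 'y')

c = matmul_basic(a, b)  # blocks concatenated over 'x': shape (8, 4)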