shard_map: add API docs

Jake VanderPlas 2024-05-13 13:04:15 -07:00
parent e4f3b3ff8f
commit 72e9eb9367
3 changed files with 65 additions and 4 deletions


@@ -26,6 +26,7 @@ Experimental Modules
jax.experimental.compilation_cache
jax.experimental.key_reuse
jax.experimental.mesh_utils
jax.experimental.shard_map
Experimental APIs
-----------------


@@ -0,0 +1,12 @@
``jax.experimental.shard_map`` module
=====================================
.. automodule:: jax.experimental.shard_map
API
---
.. autosummary::
:toctree: _autosummary
shard_map


@@ -20,7 +20,7 @@ import inspect
import itertools as it
from math import prod
import operator as op
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, TypeVar, Union
import numpy as np
@@ -29,7 +29,6 @@ import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import array
from jax._src import callback
from jax._src import core
from jax._src import custom_derivatives
@@ -52,8 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
special, control_flow, ann)
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists, split_list, subs_list2,
weakref_lru_cache)
merge_lists, split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -82,6 +80,56 @@ AxisName = Hashable
@traceback_util.api_boundary
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
"""Map a function over shards of data.
Note:
``shard_map`` is an experimental API and still subject to change. For an
introduction to sharded data, refer to :ref:`sharded-computation`. For a more
in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
Args:
f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
takes as input a shard of the mapped-over arguments and produces a shard
of the output.
mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
to shard the data and on which to execute instances of ``f``. The axis
names of the ``Mesh`` can be used in collective communication operations in ``f``.
This is typically created by a utility function like
:func:`jax.experimental.mesh_utils.create_device_mesh`.
in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
with a tree structure that is a tree prefix of the args tuple to be mapped
over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec``
represents how the corresponding argument (or subtree of arguments) should
be sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
mentioning a ``mesh`` axis name at a position expresses sharding the
corresponding argument array axis along that positional axis; not
mentioning an axis name expresses replication.
out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
with a tree structure that is a tree prefix of the output of ``f``. Each
``PartitionSpec`` represents how the corresponding output shards should be
concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
a position expresses concatenation of that mesh axis's shards along the
corresponding positional axis. Not mentioning a ``mesh`` axis name
expresses a promise that the output values are equal along that mesh axis,
and that, rather than concatenating, only a single value should be produced.
check_rep: If True (default), enable additional validity checks and automatic
differentiation optimizations. The validity checks verify that mesh axis
names not mentioned in ``out_specs`` are consistent with how the outputs
of ``f`` are replicated. Must be set to False if ``f`` uses a Pallas kernel.
auto: (experimental) an optional set of axis names from ``mesh`` over which we
do not shard the data or map the function, but instead allow the
compiler to control sharding. These names cannot be used in ``in_specs``,
``out_specs``, or in communication collectives in ``f``.
Returns:
A callable that applies the input function ``f`` across data sharded according to
the ``mesh`` and ``in_specs``.
Examples:
For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_.
.. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
"""
return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
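
As a usage sketch of the API documented above (an editor's illustration, not part of this commit's diff): the example below assumes eight available devices, e.g. a CPU host started with XLA_FLAGS=--xla_force_host_platform_device_count=8. The (4, 2) mesh shape, the axis names 'x' and 'y', and the function name matmul_basic are illustrative choices; the pattern follows the linked shard_map tutorial.

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Illustrative setup: a 4x2 device mesh with axis names 'x' and 'y'.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # Each instance receives a (2, 8) block of `a` and an (8, 4) block of
  # `b`. The partial products are summed across the 'y' mesh axis with a
  # collective; after the psum the result is equal along 'y', so
  # out_specs may omit 'y' (the replication promise described above).
  c_partial = jnp.dot(a_block, b_block)
  return jax.lax.psum(c_partial, 'y')

a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
c = matmul_basic(a, b)  # equals jnp.dot(a, b)

Note that mentioning 'y' in in_specs but not out_specs is what makes the psum necessary: without it, check_rep would reject the output as failing the replication promise along 'y'.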