Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
shard_map: add API docs
parent e4f3b3ff8f
commit 72e9eb9367
@@ -26,6 +26,7 @@ Experimental Modules
    jax.experimental.compilation_cache
    jax.experimental.key_reuse
    jax.experimental.mesh_utils
    jax.experimental.shard_map

Experimental APIs
-----------------
docs/jax.experimental.shard_map.rst (new file, 12 lines)
@@ -0,0 +1,12 @@
``jax.experimental.shard_map`` module
=====================================

.. automodule:: jax.experimental.shard_map

API
---

.. autosummary::
   :toctree: _autosummary

   shard_map
@@ -20,7 +20,7 @@ import inspect
import itertools as it
from math import prod
import operator as op
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, TypeVar, Union

import numpy as np
@@ -29,7 +29,6 @@ import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import array
from jax._src import callback
from jax._src import core
from jax._src import custom_derivatives
@@ -52,8 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                          special, control_flow, ann)
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                           as_hashable_function, memoize, partition_list,
                           merge_lists, split_list, subs_list2,
                           weakref_lru_cache)
                           merge_lists, split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -82,6 +80,56 @@ AxisName = Hashable
@traceback_util.api_boundary
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
              check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
  """Map a function over shards of data.

  Note:
    ``shard_map`` is an experimental API, and still subject to change. For an
    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
      takes as input a shard of the mapped-over arguments and produces a shard
      of the output.
    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
      to shard the data and on which to execute instances of ``f``. The names of
      the ``Mesh`` can be used in collective communication operations in ``f``.
      This is typically created by a utility function like
      :func:`jax.experimental.mesh_utils.create_device_mesh`.
    in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
      with a tree structure that is a tree prefix of the args tuple to be mapped
      over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec``
      represents how the corresponding argument (or subtree of arguments) should
      be sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
      mentioning a ``mesh`` axis name at a position expresses sharding the
      corresponding argument array axis along that positional axis; not
      mentioning an axis name expresses replication.
    out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
      with a tree structure that is a tree prefix of the output of ``f``. Each
      ``PartitionSpec`` represents how the corresponding output shards should be
      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
      a position expresses concatenation of that mesh axis's shards along the
      corresponding positional axis. Not mentioning a ``mesh`` axis name
      expresses a promise that the output values are equal along that mesh axis,
      and that, rather than concatenating, only a single value should be produced.
    check_rep: If True (default), enable additional validity checks and automatic
      differentiation optimizations. The validity checks concern whether any mesh
      axis names not mentioned in ``out_specs`` are consistent with how the outputs
      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
    auto: (experimental) an optional set of axis names from ``mesh`` over which we
      do not shard the data or map the function, but rather we allow the
      compiler to control sharding. These names cannot be used in ``in_specs``,
      ``out_specs``, or in communication collectives in ``f``.

  Returns:
    A callable that applies the input function ``f`` across data sharded according to
    the ``mesh`` and ``in_specs``.

  Examples:
    For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_.

  .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
  """
  return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto)


def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
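As a usage illustration of the API documented above (not part of this commit), here is a minimal sketch in the spirit of the basic sharded-matmul example from the `SPMD multi-device parallelism with shard_map` tutorial linked in the docstring. The 2x4 mesh shape, the axis names ``'x'`` and ``'y'``, and the array sizes are illustrative assumptions; running it requires 8 available devices.

# Illustrative sketch only: mesh shape, axis names, and array sizes are
# assumptions, and 8 devices must be available.
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Build a 2x4 device mesh and name its axes 'x' and 'y'.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange(8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

@partial(shard_map, mesh=mesh,
         in_specs=(P('x', 'y'), P('y', None)),  # shard `a` over both axes, `b` over its rows
         out_specs=P('x', None))                # concatenate results along 'x' only
def matmul_basic(a_block, b_block):
  # Each instance sees a (4, 4) block of `a` and a (4, 4) block of `b`.
  c_partial = jnp.dot(a_block, b_block)
  # Sum the partial products over the contracted 'y' mesh axis; the result is
  # then equal across 'y', which is why out_specs can omit 'y'.
  return jax.lax.psum(c_partial, 'y')

c = matmul_basic(a, b)  # c.shape == (8, 4)

Because the ``psum`` makes every instance's output equal along ``'y'``, omitting ``'y'`` from ``out_specs`` is the replication promise described in the docstring; with ``check_rep=True`` (the default) that promise is checked rather than taken on faith.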