mirror of https://github.com/ROCm/jax.git (synced 2025-04-26 06:36:07 +00:00)

shard_map: add API docs

commit 72e9eb9367 (parent e4f3b3ff8f)

@@ -26,6 +26,7 @@ Experimental Modules
     jax.experimental.compilation_cache
     jax.experimental.key_reuse
     jax.experimental.mesh_utils
+    jax.experimental.shard_map

 Experimental APIs
 -----------------

docs/jax.experimental.shard_map.rst (new file, 12 lines)
@@ -0,0 +1,12 @@
+``jax.experimental.shard_map`` module
+=====================================
+
+.. automodule:: jax.experimental.shard_map
+
+API
+---
+
+.. autosummary::
+  :toctree: _autosummary
+
+  shard_map

@@ -20,7 +20,7 @@ import inspect
 import itertools as it
 from math import prod
 import operator as op
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, TypeVar, Union
 
 import numpy as np
 
@@ -29,7 +29,6 @@ import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec, Mesh
 from jax._src import ad_checkpoint
 from jax._src import ad_util
-from jax._src import array
 from jax._src import callback
 from jax._src import core
 from jax._src import custom_derivatives
@@ -52,8 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           special, control_flow, ann)
 from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                            as_hashable_function, memoize, partition_list,
-                           merge_lists, split_list, subs_list2,
-                           weakref_lru_cache)
+                           merge_lists, split_list, subs_list2)
 from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -82,6 +80,56 @@ AxisName = Hashable
 @traceback_util.api_boundary
 def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
               check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
+  """Map a function over shards of data.
+
+  Note:
+    ``shard_map`` is an experimental API, and still subject to change. For an
+    introduction to sharded data, refer to :ref:`sharded-computation`. For a more
+    in-depth look at using ``shard_map``, refer to `SPMD multi-device parallelism with shard_map`_.
+
+  Args:
+    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``f``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``f``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
+      with a tree structure that is a tree prefix of the args tuple to be mapped
+      over. Similar to :class:`~jax.sharding.NamedSharding`, each ``PartitionSpec``
+      represents how the corresponding argument (or subtree of arguments) should
+      be sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
+      mentioning a ``mesh`` axis name at a position expresses sharding the
+      corresponding argument array axis along that positional axis; not
+      mentioning an axis name expresses replication.
+    out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves,
+      with a tree structure that is a tree prefix of the output of ``f``. Each
+      ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+      a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis. Not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that rather than concatenating, only a single value should be produced.
+    check_rep: If True (default), enable additional validity checks and automatic
+      differentiation optimizations. The validity checks concern whether any mesh
+      axis names not mentioned in ``out_specs`` are consistent with how the outputs
+      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
+    auto: (experimental) an optional set of axis names from ``mesh`` over which we
+      do not shard the data or map the function, but rather we allow the
+      compiler to control sharding. These names cannot be used in ``in_specs``,
+      ``out_specs``, or in communication collectives in ``f``.
+
+  Returns:
+    A callable that applies the input function ``f`` across data sharded according to
+    the ``mesh`` and ``in_specs``.
+
+  Examples:
+    For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_.
+
+  .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
+  """
   return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
 
 def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
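
The docstring added above describes the API abstractly; the sketch below shows how the pieces fit together in practice. It is an illustrative example written for this note, not part of the commit: the mesh shape, array sizes, and the name ``matmul_basic`` are arbitrary choices, and it assumes a runtime with at least 8 devices (adjust ``create_device_mesh`` otherwise).

    # Minimal sketch of shard_map usage (hypothetical example, assumes 8 devices).
    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental import mesh_utils
    from jax.experimental.shard_map import shard_map

    # Build a 4x2 device mesh with named axes 'x' and 'y'.
    devices = mesh_utils.create_device_mesh((4, 2))
    mesh = Mesh(devices, axis_names=('x', 'y'))

    a = jnp.arange(8 * 16.).reshape(8, 16)
    b = jnp.arange(16 * 4.).reshape(16, 4)

    @partial(shard_map, mesh=mesh,
             in_specs=(P('x', 'y'), P('y', None)),
             out_specs=P('x', None))
    def matmul_basic(a_block, b_block):
      # Each instance sees a (2, 8) block of `a` and an (8, 4) block of `b`.
      c_partial = jnp.dot(a_block, b_block)
      # Sum the partial products across the 'y' mesh axis; afterwards every
      # device along 'y' holds the same (2, 4) block.
      return jax.lax.psum(c_partial, 'y')

    c = matmul_basic(a, b)   # same values as jnp.dot(a, b), computed shard-by-shard

Because the ``psum`` makes the result identical across the 'y' axis, ``out_specs=P('x', None)`` can omit 'y'; that is the replication promise the docstring describes, and with ``check_rep=True`` (the default) it is verified rather than assumed.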