Improve the sharding documentation.

* do some proofreading.
* add PmapSharding and GSPMDSharding, which are both missing.
parent a184b5e4af
commit d4336c1d8c
@@ -13,10 +13,19 @@ Classes
.. autoclass:: XLACompatibleSharding
  :members:
  :show-inheritance:
+.. autoclass:: SingleDeviceSharding
+  :members:
+  :show-inheritance:
.. autoclass:: NamedSharding
  :members:
  :show-inheritance:
-.. autoclass:: SingleDeviceSharding
.. autoclass:: PositionalSharding
  :members:
  :show-inheritance:
+.. autoclass:: PmapSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: GSPMDSharding
+  :members:
+  :show-inheritance:
.. autoclass:: PartitionSpec
@@ -25,13 +25,13 @@ _UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()


class PartitionSpec(tuple):
-  """Tuple describing how to partition tensor into mesh .
+  """Tuple describing how to partition an array across a mesh of devices.

-  Each element is either None, string or a tuple of strings.
-  See``NamedSharding`` class for more details.
+  Each element is either ``None``, a string, or a tuple of strings.
+  See the documentation of :class:`jax.sharding.NamedSharding` for more details.

-  We create a separate class for this so JAX's pytree utilities can distinguish
-  it from a tuple that should be treated as a pytree.
+  This class exists so JAX's pytree utilities can distinguish a partition
+  specification from tuples that should be treated as pytrees.
  """

  # A sentinel value representing a dim is unconstrained.
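
As a quick illustration of the ``PartitionSpec`` semantics described above, a minimal sketch (the mesh axis name ``'x'`` is an assumed example, not taken from this diff)::

    from jax.sharding import PartitionSpec

    # Shard the first array dimension over the assumed mesh axis 'x';
    # None means the second dimension is left unpartitioned (replicated).
    spec = PartitionSpec('x', None)

    # PartitionSpec is a tuple subclass, so it indexes and unpacks like a tuple.
    assert spec[0] == 'x' and spec[1] is None
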
@@ -38,14 +38,13 @@ def _addressable_devices_indices_map(

@util.use_cpp_class(xc.Sharding)
class Sharding:
-  """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
-  across devices.
+  """Describes how a :class:`jax.Array` is laid out across devices.
  """

  # Abstract methods below that subclasses should implement.
  @property
  def device_set(self) -> set[Device]:
-    """A ``set`` of global devices that this ``Sharding`` spans.
+    """The set of devices that this :class:`Sharding` spans.

    In multi-controller JAX, the set of devices is global, i.e., includes
    non-addressable devices from other processes.
@@ -54,9 +53,9 @@ class Sharding:

  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
-    """A global mapping from device to the slice of the global data it contains.
+    """Returns a mapping from devices to the array slices each contains.

-    The devices in this mapping are global devices i.e. includes
+    The mapping includes all global devices, i.e., including
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')
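
To make ``devices_indices_map`` concrete, a minimal single-device sketch (assuming at least one visible device; the mapping then has a single entry whose index selects the whole array)::

    import jax
    from jax.sharding import SingleDeviceSharding

    # The lone device holds the full (8, 2) array, so its index is a pair of
    # full slices: (slice(None, None, None), slice(None, None, None)).
    sharding = SingleDeviceSharding(jax.devices()[0])
    print(sharding.devices_indices_map((8, 2)))
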
@@ -64,25 +63,29 @@ class Sharding:
  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of the data on each device.

-    The shard shape returned by this function is calculated from the global
-    shape (it takes as an input) and the properties of the sharding.
+    The shard shape returned by this function is calculated from
+    ``global_shape`` and the properties of the sharding.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  def is_equivalent_to(self, other: Sharding, ndim: int) -> bool:
-    """Returns True if two shardings put the same logical array
-    (sharded/unsharded) on the same device(s).
+    """Returns ``True`` if two shardings are equivalent.

-    For example, every XLACompatibleSharding lowers to GSPMDSharding which
-    is a general representation. So `jax.sharding.NamedSharding` is equivalent
-    to `jax.sharding.PositionalSharding` if both of them lower to the same
-    GSPMDSharding.
+    Two shardings are equivalent if they place the same logical array shards on
+    the same devices.
+
+    For example, a :class:`NamedSharding` may be equivalent
+    to a :class:`PositionalSharding` if both place the same shards of the array
+    on the same devices.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @property
  def is_fully_replicated(self) -> bool:
-    """Returns if a sharding is fully replicated on all the devices."""
+    """Is this sharding fully replicated?
+
+    A sharding is fully replicated if each device has a complete copy of the
+    entire data."""
    raise NotImplementedError('Subclasses should implement this method.')

  @property
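
A hedged sketch of how ``shard_shape`` and ``is_fully_replicated`` read for one concrete sharding (``PositionalSharding`` over whatever devices are visible; the array shape is illustrative)::

    import jax
    from jax.sharding import PositionalSharding

    devices = jax.devices()
    # Lay all devices out along the first axis of a 2-D device grid.
    sharding = PositionalSharding(devices).reshape(len(devices), 1)

    global_shape = (8 * len(devices), 4)
    print(sharding.shard_shape(global_shape))  # each device holds an (8, 4) slice
    print(sharding.is_fully_replicated)        # False unless only one device is visible
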
@@ -95,7 +98,9 @@ class Sharding:

  @functools.cached_property
  def addressable_devices(self) -> set[Device]:
-    """A set of devices that are addressable by the current process."""
+    """The set of devices in the :class:`Sharding` that are addressable by the
+    current process.
+    """
    # Add a fast path for single controller runtimes.
    if xb.process_count() == 1:
      return self.device_set
@@ -104,14 +109,17 @@ class Sharding:

  @functools.cached_property
  def is_fully_addressable(self) -> bool:
-    """True if the current process can address all of the devices in device_set.
+    """Is this sharding fully addressable?
+
+    A sharding is fully addressable if the current process can address all of
+    the devices named in the :class:`Sharding`.
    """
    # The pytype disable is because pytype can't recognize a cached property.
    return len(self.device_set) == len(self.addressable_devices) # type: ignore

  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index | None]:
-    """A mapping from addressable device to the slice of global data it contains.
+    """A mapping from addressable devices to the slice of array data each contains.

    ``addressable_devices_indices_map`` contains that part of
    ``device_indices_map`` that applies to the addressable devices.
@@ -22,7 +22,6 @@ import enum
import functools
import itertools
import math
import operator as op
from typing import Any, NamedTuple, Union, cast

from jax._src import mesh as mesh_lib
@@ -51,10 +50,10 @@ XLADeviceAssignment = tuple[Device, ...]
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
class XLACompatibleSharding(sharding.Sharding):
-  """A `Sharding` that describes shardings expressible to XLA.
+  """A :class:`Sharding` that describes shardings expressible to XLA.

-  Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
-  with all JAX APIs and transformations that use XLA.
+  Subclasses of :class:`XLACompatibleSharding` work with
+  all JAX APIs and transformations that use XLA.
  """

  # Abstract methods below that subclasses should implement.
@@ -157,29 +156,30 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]

@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
-  r"""NamedSharding is a way to express ``Sharding``\s using named axes.
+  r"""A :class:`NamedSharding` expresses sharding using named axes.

-  ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.
+  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and a
+  :class:`PartitionSpec` which describes how to shard an array across that
+  mesh.

-  ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
-  where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
-  ``Mesh`` is "logical mesh".
+  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
+  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

-  ``PartitionSpec`` is a tuple, whose elements can be a ``None``,
-  a mesh axis or a tuple of mesh axes. Each element describes how an input
+  A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
+  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
-  PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
+  ``PartitionSpec('x', 'y')`` says that the first dimension of data
  is sharded across ``x`` axis of the mesh, and the second dimension is sharded
  across ``y`` axis of the mesh.

  The Distributed arrays and automatic parallelization
  (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
-  goes into more details and has diagrams to help explain the concept about
-  ``Mesh`` and ``PartitionSpec``.
+  tutorial has more details and diagrams that explain how
+  :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
-    mesh: A ``jax.sharding.Mesh`` object.
-    spec: A ``jax.sharding.PartitionSpec`` object.
+    mesh: A :class:`jax.sharding.Mesh` object.
+    spec: A :class:`jax.sharding.PartitionSpec` object.

  Example:
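
To make the ``Mesh``/``PartitionSpec`` pairing concrete, a minimal sketch that builds a trivial two-axis mesh from whatever devices are visible (the axis names ``'x'`` and ``'y'`` and the array shape are assumptions for illustration)::

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    n = len(jax.devices())
    # An (n, 1) logical mesh: axis 'x' spans the devices, axis 'y' is trivial.
    mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=('x', 'y'))
    sharding = NamedSharding(mesh, PartitionSpec('x', 'y'))

    # Dimension 0 is split n ways across 'x'; dimension 1 is "split" once across 'y'.
    x = jax.device_put(np.zeros((4 * n, 2)), sharding)
    print(x.sharding)
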
@@ -334,7 +334,7 @@ def get_replicated_hlo_sharding():

@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(XLACompatibleSharding):
-  """A subclass of ``XLACompatibleSharding`` that places its data on a single device.
+  """A :class:`Sharding` that places its data on a single device.

  Args:
    device: A single :py:class:`Device`.
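
A short usage sketch for the class above (picking ``jax.devices()[0]`` is just an example)::

    import numpy as np
    import jax
    from jax.sharding import SingleDeviceSharding

    # Pin an array to one specific device; every Sharding query then reports
    # that single device.
    sharding = SingleDeviceSharding(jax.devices()[0])
    x = jax.device_put(np.ones((2, 3)), sharding)
    print(x.sharding.device_set)
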
@@ -398,6 +398,7 @@ class SingleDeviceSharding(XLACompatibleSharding):

@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
+  """Describes a sharding used by :func:`jax.pmap`."""
  devices: np.ndarray
  sharding_spec: sharding_specs.ShardingSpec

@@ -443,16 +444,15 @@ class PmapSharding(XLACompatibleSharding):
  @classmethod
  def default(cls, shape: Shape, sharded_dim: int = 0,
              devices: Sequence[xc.Device] | None = None) -> PmapSharding:
-    """Creates a `PmapSharding` which matches the implicit device order used by
-    `pmap` if devices is None. If devices is specified, it will use those
-    devices.
+    """Creates a :class:`PmapSharding` which matches the default placement
+    used by :func:`jax.pmap`.

    Args:
      shape: The shape of the input array.
      sharded_dim: Dimension the input array is sharded on. Defaults to 0.
-      devices: Optional sequence of devices used to create PmapSharding. If not
-        specified, it will use the implicit device order used by pmap which is
-        the order of jax.local_devices()
+      devices: Optional sequence of devices to use. If omitted, the implicit
+        device order used by pmap is used, which is the order of
+        :func:`jax.local_devices`.
    """
    # The dtype doesn't matter here. Its only used for creating the
    # sharding_spec.
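
A hedged sketch of ``PmapSharding.default``, assuming the leading dimension of the input equals the local device count (the layout :func:`jax.pmap` itself expects)::

    import jax
    from jax.sharding import PmapSharding

    n = jax.local_device_count()
    shape = (n, 3)
    # Matches pmap's implicit placement: one slice of dimension 0 per local device.
    sharding = PmapSharding.default(shape, sharded_dim=0)
    print(sharding.devices_indices_map(shape))
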
@@ -571,8 +571,13 @@ class PositionalSharding(XLACompatibleSharding):
      # Will error if memory_kind does not exist on the device.
      self._devices[0].memory(self._memory_kind)

-  shape = property(op.attrgetter('_ids.shape'))
-  ndim = property(op.attrgetter('_ids.ndim'))
+  @property
+  def shape(self):
+    return self._ids.shape
+
+  @property
+  def ndim(self):
+    return self._ids.ndim

  def __repr__(self) -> str:
    cls_name = self.__class__.__name__
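
The two properties added above simply mirror the shape of the underlying device-id array; a small sketch of how they read in practice::

    import jax
    from jax.sharding import PositionalSharding

    sharding = PositionalSharding(jax.devices())
    print(sharding.shape, sharding.ndim)   # e.g. (8,) and 1 on an 8-device host

    reshaped = sharding.reshape(len(jax.devices()), 1)
    print(reshaped.shape, reshaped.ndim)   # (num_devices, 1) and 2
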