Improve the sharding documentation.

* do some proofreading.
* add PmapSharding and GSPMDSharding, which are both missing.
Peter Hawkins 2023-08-03 10:15:09 -04:00
parent a184b5e4af
commit d4336c1d8c
4 changed files with 70 additions and 48 deletions

View File

@@ -13,10 +13,19 @@ Classes
.. autoclass:: XLACompatibleSharding
:members:
:show-inheritance:
.. autoclass:: SingleDeviceSharding
:members:
:show-inheritance:
.. autoclass:: NamedSharding
:members:
:show-inheritance:
.. autoclass:: SingleDeviceSharding
.. autoclass:: PositionalSharding
:members:
:show-inheritance:
.. autoclass:: PmapSharding
:members:
:show-inheritance:
.. autoclass:: GSPMDSharding
:members:
:show-inheritance:
.. autoclass:: PartitionSpec

View File

@@ -25,13 +25,13 @@ _UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
class PartitionSpec(tuple):
"""Tuple describing how to partition tensor into mesh .
"""Tuple describing how to partition an array across a mesh of devices.
Each element is either None, string or a tuple of strings.
See``NamedSharding`` class for more details.
Each element is either ``None``, a string, or a tuple of strings.
See the documentation of :class:`jax.sharding.NamedSharding` for more details.
We create a separate class for this so JAX's pytree utilities can distinguish
it from a tuple that should be treated as a pytree.
This class exists so JAX's pytree utilities can distinguish partition
specifications from tuples that should be treated as pytrees.
"""
# A sentinel value representing that a dim is unconstrained.
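For readers skimming the hunk above, a minimal sketch (not part of the commit) of the element kinds the reworded docstring names: a mesh-axis name, a tuple of mesh-axis names, and ``None`` for an unpartitioned dimension.

from jax.sharding import PartitionSpec

# dim 0 split over mesh axis 'x', dim 1 over both 'y' and 'z',
# dim 2 left unpartitioned (None).
spec = PartitionSpec('x', ('y', 'z'), None)
print(isinstance(spec, tuple))  # True: PartitionSpec subclasses tuple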

View File

@@ -38,14 +38,13 @@ def _addressable_devices_indices_map(
@util.use_cpp_class(xc.Sharding)
class Sharding:
"""Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
across devices.
"""Describes how a :class:`jax.Array` is laid out across devices.
"""
# Abstract methods below that subclasses should implement.
@property
def device_set(self) -> set[Device]:
"""A ``set`` of global devices that this ``Sharding`` spans.
"""The set of devices that this :class:`Sharding` spans.
In multi-controller JAX, the set of devices is global, i.e., it includes
non-addressable devices from other processes.
@@ -54,9 +53,9 @@ class Sharding:
def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index | None]:
"""A global mapping from device to the slice of the global data it contains.
"""Returns a mapping from devices to the array slices each contains.
The devices in this mapping are global devices i.e. includes
The mapping covers all global devices, i.e., it includes
non-addressable devices from other processes.
"""
raise NotImplementedError('Subclasses should implement this method.')
@@ -64,25 +63,29 @@ class Sharding:
def shard_shape(self, global_shape: Shape) -> Shape:
"""Returns the shape of the data on each device.
The shard shape returned by this function is calculated from the global
shape (it takes as an input) and the properties of the sharding.
The shard shape returned by this function is calculated from
``global_shape`` and the properties of the sharding.
"""
raise NotImplementedError('Subclasses should implement this method.')
def is_equivalent_to(self, other: Sharding, ndim: int) -> bool:
"""Returns True if two shardings put the same logical array
(sharded/unsharded) on the same device(s).
"""Returns ``True`` if two shardings are equivalent.
For example, every XLACompatibleSharding lowers to GSPMDSharding which
is a general representation. So `jax.sharding.NamedSharding` is equivalent
to `jax.sharding.PositionalSharding` if both of them lower to the same
GSPMDSharding.
Two shardings are equivalent if they place the same logical array shards on
the same devices.
For example, a :class:`NamedSharding` may be equivalent
to a :class:`PositionalSharding` if both place the same shards of the array
on the same devices.
"""
raise NotImplementedError('Subclasses should implement this method.')
@property
def is_fully_replicated(self) -> bool:
"""Returns if a sharding is fully replicated on all the devices."""
"""Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the
entire data."""
raise NotImplementedError('Subclasses should implement this method.')
@property
@@ -95,7 +98,9 @@ class Sharding:
@functools.cached_property
def addressable_devices(self) -> set[Device]:
"""A set of devices that are addressable by the current process."""
"""The set of devices in the :class:`Sharding` that are addressable by the
current process.
"""
# Add a fast path for single controller runtimes.
if xb.process_count() == 1:
return self.device_set
@@ -104,14 +109,17 @@ class Sharding:
@functools.cached_property
def is_fully_addressable(self) -> bool:
"""True if the current process can address all of the devices in device_set.
"""Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of
the devices named in the :class:`Sharding`.
"""
# The pytype disable is because pytype can't recognize a cached property.
return len(self.device_set) == len(self.addressable_devices) # type: ignore
def addressable_devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index | None]:
"""A mapping from addressable device to the slice of global data it contains.
"""A mapping from addressable devices to the slice of array data each contains.
``addressable_devices_indices_map`` contains that part of
``devices_indices_map`` that applies to the addressable devices.
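The properties and methods touched in these hunks are easiest to see on a concrete array. A minimal sketch (not part of the commit), assuming a single-process, single-device run:

import jax.numpy as jnp

x = jnp.arange(8)
s = x.sharding                         # a SingleDeviceSharding
print(s.device_set)                    # the one device this sharding spans
print(s.is_fully_addressable)          # True: one process addresses every device
print(s.is_fully_replicated)           # True: that device holds all the data
print(s.devices_indices_map(x.shape))  # {device: (slice(None, None, None),)}
print(s.shard_shape(x.shape))          # (8,)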

View File

@@ -22,7 +22,6 @@ import enum
import functools
import itertools
import math
import operator as op
from typing import Any, NamedTuple, Union, cast
from jax._src import mesh as mesh_lib
@@ -51,10 +50,10 @@ XLADeviceAssignment = tuple[Device, ...]
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
class XLACompatibleSharding(sharding.Sharding):
"""A `Sharding` that describes shardings expressible to XLA.
"""A :class:`Sharding` that describes shardings expressible to XLA.
Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
with all JAX APIs and transformations that use XLA.
Subclasses of :class:`XLACompatibleSharding` work with
all JAX APIs and transformations that use XLA.
"""
# Abstract methods below that subclasses should implement.
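Illustrative only: the concrete sharding classes listed in the .rst hunk at the top of this commit all subclass XLACompatibleSharding, which is what the reworded docstring is saying.

from jax.sharding import (GSPMDSharding, NamedSharding, PmapSharding,
                          PositionalSharding, SingleDeviceSharding,
                          XLACompatibleSharding)

for cls in (SingleDeviceSharding, NamedSharding, PositionalSharding,
            PmapSharding, GSPMDSharding):
  print(cls.__name__, issubclass(cls, XLACompatibleSharding))  # all True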
@@ -157,29 +156,30 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
r"""A :class:`NamedSharding` expresses sharding using named axes.
``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
a :class:`PartitionSpec` which describes how to shard an array across that
mesh.
``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
``Mesh`` is "logical mesh".
A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.
``PartitionSpec`` is a tuple, whose elements can be a ``None``,
a mesh axis or a tuple of mesh axes. Each element describes how an input
A :class:`PartitionSpec` is a tuple whose elements can be ``None``,
a mesh axis, or a tuple of mesh axes. Each element describes how an input
dimension is partitioned across zero or more mesh dimensions. For example,
PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
``PartitionSpec('x', 'y')`` says that the first dimension of data
is sharded across the ``x`` axis of the mesh, and the second dimension is sharded
across the ``y`` axis of the mesh.
The Distributed arrays and automatic parallelization
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
goes into more details and has diagrams to help explain the concept about
``Mesh`` and ``PartitionSpec``.
tutorial has more details and diagrams that explain how
:class:`Mesh` and :class:`PartitionSpec` are used.
Args:
mesh: A ``jax.sharding.Mesh`` object.
spec: A ``jax.sharding.PartitionSpec`` object.
mesh: A :class:`jax.sharding.Mesh` object.
spec: A :class:`jax.sharding.PartitionSpec` object.
Example:
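(The docstring's own ``Example:`` block is cut off by the hunk boundary above.) For orientation, a separate sketch of the idea, not the docstring's example; it assumes a runtime with at least 8 devices:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=('x', 'y'))   # 4x2 logical mesh
spec = PartitionSpec('x', 'y')                # dim 0 over 'x', dim 1 over 'y'
sharding = NamedSharding(mesh, spec)

x = jax.device_put(np.arange(32.0).reshape(8, 4), sharding)
print(sharding.shard_shape(x.shape))          # (2, 2)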
@@ -334,7 +334,7 @@ def get_replicated_hlo_sharding():
@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(XLACompatibleSharding):
"""A subclass of ``XLACompatibleSharding`` that places its data on a single device.
"""A :class:`Sharding` that places its data on a single device.
Args:
device: A single :py:class:`Device`.
@@ -398,6 +398,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
"""Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec
@@ -443,16 +444,15 @@ class PmapSharding(XLACompatibleSharding):
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0,
devices: Sequence[xc.Device] | None = None) -> PmapSharding:
"""Creates a `PmapSharding` which matches the implicit device order used by
`pmap` if devices is None. If devices is specified, it will use those
devices.
"""Creates a :class:`PmapSharding` which matches the default placement
used by :func:`jax.pmap`.
Args:
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
devices: Optional sequence of devices used to create PmapSharding. If not
specified, it will use the implicit device order used by pmap which is
the order of jax.local_devices()
devices: Optional sequence of devices to use. If omitted, the implicit
device order used by pmap is used, which is the order of
:func:`jax.local_devices`.
"""
# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
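A quick sketch (not from the commit) of the behavior the rewritten docstring describes: with ``devices`` omitted, the sharding follows the implicit device order of pmap, i.e. jax.local_devices().

import jax
from jax.sharding import PmapSharding

n = jax.local_device_count()
sharding = PmapSharding.default(shape=(n, 4), sharded_dim=0)
print(sharding.devices)  # laid out in the order of jax.local_devices()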
@@ -571,8 +571,13 @@ class PositionalSharding(XLACompatibleSharding):
# Will error if memory_kind does not exist on the device.
self._devices[0].memory(self._memory_kind)
shape = property(op.attrgetter('_ids.shape'))
ndim = property(op.attrgetter('_ids.ndim'))
@property
def shape(self):
return self._ids.shape
@property
def ndim(self):
return self._ids.ndim
def __repr__(self) -> str:
cls_name = self.__class__.__name__
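The final hunk replaces ``op.attrgetter``-based properties with explicit ``@property`` methods. A toy comparison (not JAX code) showing that the two spellings behave the same on attribute access; the explicit form can additionally carry a docstring and a return annotation.

import operator as op
import numpy as np

class ViaAttrgetter:
  def __init__(self, ids):
    self._ids = np.asarray(ids)
  # property built from an attrgetter: resolves self._ids.shape on access
  shape = property(op.attrgetter('_ids.shape'))

class ViaDecorator:
  def __init__(self, ids):
    self._ids = np.asarray(ids)
  @property
  def shape(self):
    # explicit property method with the same behavior
    return self._ids.shape

print(ViaAttrgetter([[1, 2], [3, 4]]).shape)  # (2, 2)
print(ViaDecorator([[1, 2], [3, 4]]).shape)   # (2, 2)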