Add docstrings for the Sharding classes. For now, only Sharding, XLACompatibleSharding, NamedSharding and SingleDeviceSharding are documented.

Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
Yash Katariya 2022-11-14 15:47:06 -08:00 committed by jax authors
parent c42bad85ef
commit e6c4d4a30e
5 changed files with 97 additions and 4 deletions
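
For orientation, this is the class hierarchy the new docstrings cover, reconstructed from the diff below (the tree itself is not part of the commit):

# Sharding class hierarchy documented by this commit:
#
#   Sharding (abstract base; device_set, devices_indices_map, shard_shape)
#   └── XLACompatibleSharding (abstract; shardings expressible to XLA)
#       ├── NamedSharding (a Mesh plus a PartitionSpec)
#       └── SingleDeviceSharding (all data on one device)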

docs/index.rst

@@ -34,6 +34,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
    jaxpr
    notebooks/convolutions
    pytrees
+   jax_array_migration
    type_promotion
    errors
    transfer_guard
@@ -76,7 +77,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
    :caption: Notes

    api_compatibility
-   jax_array_migration
    deprecation
    concurrency
    gpu_memory_allocation

docs/jax.rst

@@ -11,6 +11,7 @@ Subpackages
    jax.numpy
    jax.scipy
+   jax.sharding
    jax.config
    jax.debug
    jax.dlpack

docs/jax.sharding.rst Normal file (21 lines)

@@ -0,0 +1,21 @@
+jax.sharding package
+====================
+
+.. automodule:: jax.sharding
+
+Classes
+-------
+
+.. currentmodule:: jax.sharding
+
+.. autoclass:: Sharding
+  :members:
+.. autoclass:: XLACompatibleSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: NamedSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: SingleDeviceSharding
+  :members:
+  :show-inheritance:

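Since Sphinx's ``autoclass`` pulls these entries from the class docstrings at import time, a quick way to sanity-check the wiring is to inspect ``__doc__`` directly. A minimal sketch, assuming a jax build recent enough to expose the ``jax.sharding`` module:

import jax.sharding

# The __doc__ attributes are what `.. autoclass::` renders; they survive on
# the C++-backed classes because use_cpp_class copies them over (see the
# pxla.py change at the end of this commit).
print(jax.sharding.Sharding.__doc__.splitlines()[0])
print(jax.sharding.SingleDeviceSharding.__doc__.splitlines()[0])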
jax/_src/sharding.py

@@ -35,24 +35,38 @@ XLADeviceAssignment = Sequence[Device]
 @pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
+  """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is
+  laid out across devices.
+  """
+
   # Abstract methods below that subclasses should implement.

   @abc.abstractproperty
   def device_set(self) -> Set[Device]:
-    """A unique set of devices that this sharding represents.
+    """A ``set`` of global devices that this ``Sharding`` spans.

-    Devices can be non-addressable too.
+    In multi-controller JAX, the set of devices is global, i.e., it includes
+    non-addressable devices from other processes.
     """
     raise NotImplementedError('Subclasses should implement this method.')

   @abc.abstractmethod
   def devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+    """A global mapping from each device to the slice of the global data it
+    contains.
+
+    The devices in this mapping are global, i.e., the mapping includes
+    non-addressable devices from other processes.
+    """
     raise NotImplementedError('Subclasses should implement this method.')

   @abc.abstractmethod
   def shard_shape(self, global_shape: Shape) -> Shape:
+    """Returns the shape of the data on each device.
+
+    The shard shape is calculated from the global shape (which this method
+    takes as input) and the properties of the sharding.
+    """
     raise NotImplementedError('Subclasses should implement this method.')

   #############################################################################
@@ -60,24 +74,38 @@ class Sharding(metaclass=abc.ABCMeta):
   @pxla.maybe_cached_property
   def addressable_devices(self) -> Set[Device]:
-    """A set of addressable devices by the current process"""
+    """A set of devices that are addressable by the current process."""
     return {d for d in self.device_set
             if d.process_index == d.client.process_index()}

   @pxla.maybe_cached_property
   def is_fully_addressable(self) -> bool:
+    """True if the current process can address all of the devices in
+    ``device_set``.
+    """
     # The pytype disable is because pytype can't recognize a cached property.
     return len(self.device_set) == len(self.addressable_devices)  # type: ignore

   @functools.lru_cache(maxsize=4096)
   def addressable_devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+    """A mapping from addressable devices to the slice of global data each
+    contains.
+
+    ``addressable_devices_indices_map`` contains the part of
+    ``devices_indices_map`` that applies to the addressable devices.
+    """
     return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
             if d.process_index == d.client.process_index()}


 # Shardings that inherit from XLACompatibleSharding should implement the
 # `_device_assignment` property and `_to_xla_op_sharding` method.
 @pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
 class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
+  """A ``Sharding`` that describes shardings expressible to XLA.
+
+  Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
+  with all JAX APIs and transformations that use XLA.
+  """

   # Abstract methods below that subclasses should implement.
@@ -173,6 +201,37 @@ def _enable_cpp_named_sharding():
 @pxla.use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
+  """NamedSharding is a way to express ``Sharding``s using named axes.
+
+  ``Mesh`` and ``PartitionSpec`` together express a ``Sharding`` with names.
+
+  ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
+  where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
+  ``Mesh`` is "logical mesh".
+
+  ``PartitionSpec`` is a tuple whose elements can be ``None``, a mesh axis,
+  or a tuple of mesh axes. Each element describes how an input dimension is
+  partitioned across zero or more mesh dimensions. For example,
+  ``PartitionSpec('x', 'y')`` says that the first dimension of the data is
+  sharded across the ``x`` axis of the mesh and the second dimension is
+  sharded across the ``y`` axis of the mesh.
+
+  The pjit tutorial
+  (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec)
+  goes into more detail and has diagrams explaining how ``Mesh`` and
+  ``PartitionSpec`` work.
+
+  Args:
+    mesh: A ``jax.experimental.maps.Mesh`` object.
+    spec: A ``jax.experimental.PartitionSpec`` object.
+
+  Example:
+
+    >>> import jax
+    >>> import numpy as np
+    >>> from jax.experimental.maps import Mesh
+    >>> from jax.experimental import PartitionSpec as P
+    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
+    >>> spec = P('x', 'y')
+    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
+  """
+
   @pxla.use_cpp_method
   def __init__(
@@ -278,6 +337,16 @@ def _get_replicated_op_sharding():
 @pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
 class SingleDeviceSharding(XLACompatibleSharding):
+  """A subclass of ``XLACompatibleSharding`` that places its data on a single
+  device.
+
+  Args:
+    device: A single :py:class:`Device`.
+
+  Example:
+
+    >>> import jax
+    >>> single_device_sharding = jax.sharding.SingleDeviceSharding(
+    ...     jax.devices()[0])
+  """
+
   @pxla.use_cpp_method
   def __init__(self, device: Device):

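To make the documented interface concrete, here is a minimal sketch that exercises the newly documented properties and methods. It uses SingleDeviceSharding so it runs on any machine in a single process, without assuming a multi-device mesh:

import jax
from jax.sharding import SingleDeviceSharding

sharding = SingleDeviceSharding(jax.devices()[0])
global_shape = (8, 4)

# The global set of devices this sharding spans (here, just one).
print(sharding.device_set)

# With a single device, the shard shape equals the global shape.
print(sharding.shard_shape(global_shape))

# Each device maps to the slice of the global data it holds; one device
# holds the whole array, i.e. (slice(None), slice(None)).
print(sharding.devices_indices_map(global_shape))

# In single-process JAX, every device in device_set is addressable.
print(sharding.is_fully_addressable)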
jax/interpreters/pxla.py

@@ -3864,6 +3864,8 @@ def use_cpp_class(cpp_cls):
           _original_func(attr), "_use_cpp"):
         setattr(cpp_cls, attr_name, attr)
+
+    cpp_cls.__doc__ = cls.__doc__
     return cpp_cls

   return wrapper
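
The hunk above shows only the tail of ``use_cpp_class``. As a rough sketch of the pattern (a hypothetical simplification, not JAX's exact implementation): the decorator copies Python-defined attributes onto an optional C++-backed class, and the line added by this commit additionally carries the docstring across so the Sphinx pages above have something to render:

def use_cpp_class(cpp_cls):
  def wrapper(cls):
    if cpp_cls is None:
      # No C++ implementation available: keep the pure-Python class.
      return cls
    for attr_name, attr in cls.__dict__.items():
      if attr_name not in ('__module__', '__dict__', '__doc__'):
        setattr(cpp_cls, attr_name, attr)
    # The change in this commit: without copying __doc__, the C++-backed
    # class would lose the docstrings that `.. autoclass::` renders.
    cpp_cls.__doc__ = cls.__doc__
    return cpp_cls
  return wrapper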