Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Add docstrings for Sharding classes. Right now I am only documenting Sharding, XLACompatibleSharding, MeshPspecSharding and SingleDeviceSharding.

Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
This commit is contained in:
parent
c42bad85ef
commit
e6c4d4a30e
@@ -34,6 +34,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
   jaxpr
   notebooks/convolutions
   pytrees
   jax_array_migration
   type_promotion
   errors
   transfer_guard
@@ -76,7 +77,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
   :caption: Notes

   api_compatibility
   jax_array_migration
   deprecation
   concurrency
   gpu_memory_allocation
@@ -11,6 +11,7 @@ Subpackages

   jax.numpy
   jax.scipy
   jax.sharding
   jax.config
   jax.debug
   jax.dlpack
21  docs/jax.sharding.rst  Normal file
@@ -0,0 +1,21 @@
jax.sharding package
====================

.. automodule:: jax.sharding

Classes
-------

.. currentmodule:: jax.sharding

.. autoclass:: Sharding
  :members:
.. autoclass:: XLACompatibleSharding
  :members:
  :show-inheritance:
.. autoclass:: NamedSharding
  :members:
  :show-inheritance:
.. autoclass:: SingleDeviceSharding
  :members:
  :show-inheritance:
@@ -35,24 +35,38 @@ XLADeviceAssignment = Sequence[Device]

@pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
class Sharding(metaclass=abc.ABCMeta):
  """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
  across devices.
  """

  # Abstract methods below that subclasses should implement.

  @abc.abstractproperty
  def device_set(self) -> Set[Device]:
    """A unique set of devices that this sharding represents.
    """A ``set`` of global devices that this ``Sharding`` spans.

    Devices can be non-addressable too.
    In multi-controller JAX, the set of devices is global, i.e., includes
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    """A global mapping from device to the slice of the global data it contains.

    The devices in this mapping are global devices, i.e., it includes
    non-addressable devices from other processes.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of the data on each device.

    The shard shape returned by this function is calculated from the global
    shape (which it takes as an input) and the properties of the sharding.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
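For readers skimming the diff, here is a minimal sketch (not part of the commit) of what these three abstract methods return for a concrete sharding. It assumes a single-process run; the axis name 'x' and the global shape are arbitrary illustration choices.

import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

# Shard the leading dimension across every local device along one mesh axis.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('x',))
s = jax.sharding.NamedSharding(mesh, P('x'))

global_shape = (8 * n, 4)
print(s.device_set)                 # all n local devices
print(s.shard_shape(global_shape))  # (8, 4): dim 0 is split n ways
for d, idx in s.devices_indices_map(global_shape).items():
  print(d, idx)                     # per-device slice of the global array, e.g. (slice(0, 8), slice(None))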
@@ -60,24 +74,38 @@ class Sharding(metaclass=abc.ABCMeta):

  @pxla.maybe_cached_property
  def addressable_devices(self) -> Set[Device]:
    """A set of addressable devices by the current process"""
    """A set of devices that are addressable by the current process."""
    return {d for d in self.device_set
            if d.process_index == d.client.process_index()}

  @pxla.maybe_cached_property
  def is_fully_addressable(self) -> bool:
    """True if the current process can address all of the devices in device_set.
    """
    # The pytype disable is because pytype can't recognize a cached property.
    return len(self.device_set) == len(self.addressable_devices)  # type: ignore

  @functools.lru_cache(maxsize=4096)
  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    """A mapping from addressable device to the slice of global data it contains.

    ``addressable_devices_indices_map`` contains that part of
    ``device_indices_map`` that applies to the addressable devices.
    """
    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
            if d.process_index == d.client.process_index()}


# Shardings that inherit from XLACompatibleSharding should implement the
# `_device_assignment` property and `_to_xla_op_sharding` method.
@pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
  """A `Sharding` that describes shardings expressible to XLA.

  Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
  with all JAX APIs and transformations that use XLA.
  """

  # Abstract methods below that subclasses should implement.
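A short follow-up sketch (not part of the commit) of how the addressable variants relate to the global ones. It assumes a single-process run, where every device in ``device_set`` is addressable, so the addressable map is simply equal to the global map.

import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
s = jax.sharding.NamedSharding(mesh, P('x'))
global_shape = (2 * len(jax.devices()), 3)

# With one process there are no non-addressable devices, so the two maps agree.
assert s.is_fully_addressable
assert s.addressable_devices == s.device_set
assert (dict(s.addressable_devices_indices_map(global_shape))
        == dict(s.devices_indices_map(global_shape)))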
@@ -173,6 +201,37 @@ def _enable_cpp_named_sharding():

@pxla.use_cpp_class(_enable_cpp_named_sharding())
class NamedSharding(XLACompatibleSharding):
  """NamedSharding is a way to express ``Sharding``s using named axes.

  ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.

  ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
  where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
  ``Mesh`` is "logical mesh".

  ``PartitionSpec`` is a named tuple, whose elements can be ``None``,
  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` is a ``PartitionSpec`` where the first dimension of data
  is sharded across the ``x`` axis of the mesh, and the second dimension is sharded
  across the ``y`` axis of the mesh.

  The pjit tutorial (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec)
  goes into more detail and has diagrams to help explain the concepts behind
  ``Mesh`` and ``PartitionSpec``.

  Args:
    mesh: A ``jax.experimental.maps.Mesh`` object.
    spec: A ``jax.experimental.PartitionSpec`` object.

  Example:

    >>> from jax.experimental.maps import Mesh
    >>> from jax.experimental import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
  """

  @pxla.use_cpp_method
  def __init__(
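To make the docstring's 2x4 example concrete, here is a small sketch (not part of the commit) showing the per-device shard shape that P('x', 'y') produces. It assumes exactly 8 devices, e.g. a CPU run with XLA_FLAGS=--xla_force_host_platform_device_count=8.

import jax
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

# Same mesh and spec as the docstring example above.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
named_sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))

# An (8, 8) array: dim 0 is split 2 ways across 'x', dim 1 is split 4 ways across 'y'.
print(named_sharding.shard_shape((8, 8)))  # expected: (4, 2)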
@@ -278,6 +337,16 @@ def _get_replicated_op_sharding():

@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):
  """A subclass of ``XLACompatibleSharding`` that places its data on a single device.

  Args:
    device: A single :py:class:`Device`.

  Example:

    >>> single_device_sharding = jax.sharding.SingleDeviceSharding(
    ...     jax.devices()[0])
  """

  @pxla.use_cpp_method
  def __init__(self, device: Device):
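As a small illustration (not part of the commit): the single device holds the entire array, so the shard shape equals the global shape and the only index is a full slice. Assumes at least one local device.

import jax

s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
print(s.device_set)                   # {jax.devices()[0]}
print(s.shard_shape((4, 3)))          # (4, 3): the whole array lives on one device
print(s.devices_indices_map((4, 3)))  # {device: (slice(None), slice(None))}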
@@ -3864,6 +3864,8 @@ def use_cpp_class(cpp_cls):
          _original_func(attr), "_use_cpp"):
        setattr(cpp_cls, attr_name, attr)

    cpp_cls.__doc__ = cls.__doc__

    return cpp_cls

  return wrapper
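The added line matters because ``use_cpp_class`` swaps the decorated Python class for a C++ implementation, and the new docstrings would otherwise be lost on the returned class. Below is a minimal sketch of that pattern with hypothetical names; it is not the real pxla implementation.

def use_alt_class(alt_cls):
  """Swap a Python class for an alternative implementation, keeping its docstring."""
  def wrapper(cls):
    if alt_cls is None:  # e.g. the extension providing the fast class is too old
      return cls
    alt_cls.__doc__ = cls.__doc__  # this is what the added line above does
    return alt_cls
  return wrapper

class _FastImpl:
  pass

@use_alt_class(_FastImpl)
class Documented:
  """This docstring survives on the class that is actually returned."""

assert Documented is _FastImpl
assert Documented.__doc__ == "This docstring survives on the class that is actually returned."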