From e6c4d4a30e5a1d780d1552122dda348dd381e847 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 14 Nov 2022 15:47:06 -0800
Subject: [PATCH] Add docstrings for `Sharding` classes.

Right now I am only documenting `Sharding`, `XLACompatibleSharding`,
`MeshPspecSharding` and `SingleDeviceSharding`.

Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
---
 docs/index.rst           |  2 +-
 docs/jax.rst             |  1 +
 docs/jax.sharding.rst    | 21 +++++++++++
 jax/_src/sharding.py     | 75 ++++++++++++++++++++++++++++++++++++++--
 jax/interpreters/pxla.py |  2 ++
 5 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 docs/jax.sharding.rst

diff --git a/docs/index.rst b/docs/index.rst
index 965d247a2..24769c4d8 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -34,6 +34,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
    jaxpr
    notebooks/convolutions
    pytrees
+   jax_array_migration
    type_promotion
    errors
    transfer_guard
@@ -76,7 +77,6 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
    :caption: Notes
 
    api_compatibility
-   jax_array_migration
    deprecation
    concurrency
    gpu_memory_allocation
diff --git a/docs/jax.rst b/docs/jax.rst
index 7396c439e..a309b3ca0 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -11,6 +11,7 @@ Subpackages
 
     jax.numpy
     jax.scipy
+    jax.sharding
     jax.config
     jax.debug
     jax.dlpack
diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst
new file mode 100644
index 000000000..9bf66cd90
--- /dev/null
+++ b/docs/jax.sharding.rst
@@ -0,0 +1,21 @@
+jax.sharding package
+====================
+
+.. automodule:: jax.sharding
+
+Classes
+-------
+
+.. currentmodule:: jax.sharding
+
+.. autoclass:: Sharding
+  :members:
+.. autoclass:: XLACompatibleSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: NamedSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: SingleDeviceSharding
+  :members:
+  :show-inheritance:
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index 253009f3d..f04cf76cb 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -35,24 +35,38 @@ XLADeviceAssignment = Sequence[Device]
 
 @pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
 class Sharding(metaclass=abc.ABCMeta):
+  """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
+  across devices.
+  """
 
   # Abstract methods below that subclasses should implement.
 
   @abc.abstractproperty
   def device_set(self) -> Set[Device]:
-    """A unique set of devices that this sharding represents.
+    """A ``set`` of global devices that this ``Sharding`` spans.
 
-    Devices can be non-addressable too.
+    In multi-controller JAX, the set of devices is global, i.e., includes
+    non-addressable devices from other processes.
     """
     raise NotImplementedError('Subclasses should implement this method.')
 
   @abc.abstractmethod
   def devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
+    """A global mapping from device to the slice of the global data it contains.
+
+    The devices in this mapping are global devices, i.e., the mapping
+    includes non-addressable devices from other processes.
+    """
    raise NotImplementedError('Subclasses should implement this method.')
 
   @abc.abstractmethod
   def shard_shape(self, global_shape: Shape) -> Shape:
+    """Returns the shape of the data on each device.
+
+    The shard shape returned by this method is calculated from the global
+    shape (which it takes as an input) and the properties of the sharding.
+ """ raise NotImplementedError('Subclasses should implement this method.') ############################################################################# @@ -60,24 +74,38 @@ class Sharding(metaclass=abc.ABCMeta): @pxla.maybe_cached_property def addressable_devices(self) -> Set[Device]: - """A set of addressable devices by the current process""" + """A set of devices that are addressable by the current process.""" return {d for d in self.device_set if d.process_index == d.client.process_index()} @pxla.maybe_cached_property def is_fully_addressable(self) -> bool: + """True if the current process can address all of the devices in device_set. + """ # The pytype disable is because pytype can't recognize a cached property. return len(self.device_set) == len(self.addressable_devices) # type: ignore @functools.lru_cache(maxsize=4096) def addressable_devices_indices_map( self, global_shape: Shape) -> Mapping[Device, Optional[Index]]: + """A mapping from addressable device to the slice of global data it contains. + + ``addressable_devices_indices_map`` contains that part of + ``device_indices_map`` that applies to the addressable devices. + """ return {d: ind for d, ind in self.devices_indices_map(global_shape).items() if d.process_index == d.client.process_index()} +# Shardings that inherit from XLACompatibleSharding should implement the +# `_device_assignment` property and `_to_xla_op_sharding` method. @pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None) class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta): + """A `Sharding` that describes shardings expressible to XLA. + + Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work + with all JAX APIs and transformations that use XLA. + """ # Abstract methods below that subclasses should implement. @@ -173,6 +201,37 @@ def _enable_cpp_named_sharding(): @pxla.use_cpp_class(_enable_cpp_named_sharding()) class NamedSharding(XLACompatibleSharding): + """NamedSharding is a way to express ``Sharding``s using named axes. + + ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name. + + ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid, + where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for + ``Mesh`` is "logical mesh". + + ``PartitionSpec`` is a named tuple, whose elements can be a ``None``, + a mesh axis or a tuple of mesh axes. Each element describes how an input + dimension is partitioned across zero or more mesh dimensions. For example, + PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data + is sharded across ``x`` axis of the mesh, and the second dimension is sharded + across ``y`` axis of the mesh. + + The pjit tutorial (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec) + goes into more details and has diagrams to help explain the concept about + ``Mesh`` and ``PartitionSpec``. + + Args: + mesh: A ``jax.experimental.maps.Mesh`` object. + spec: A ``jax.experimental.PartitionSpec`` object. 
+
+  Example:
+
+    >>> from jax.experimental.maps import Mesh
+    >>> from jax.experimental import PartitionSpec as P
+    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
+    >>> spec = P('x', 'y')
+    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
+  """
 
   @pxla.use_cpp_method
   def __init__(
@@ -278,6 +337,16 @@ def _get_replicated_op_sharding():
 
 @pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
 class SingleDeviceSharding(XLACompatibleSharding):
+  """A subclass of ``XLACompatibleSharding`` that places its data on a single device.
+
+  Args:
+    device: A single :py:class:`Device`.
+
+  Example:
+
+    >>> single_device_sharding = jax.sharding.SingleDeviceSharding(
+    ...     jax.devices()[0])
+  """
 
   @pxla.use_cpp_method
   def __init__(self, device: Device):
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index e5f6401c1..5a8c93b30 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -3864,6 +3864,8 @@ def use_cpp_class(cpp_cls):
           _original_func(attr), "_use_cpp"):
         setattr(cpp_cls, attr_name, attr)
 
+    cpp_cls.__doc__ = cls.__doc__
+
     return cpp_cls
 
   return wrapper
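
Usage sketch (illustrative, not part of this patch): the snippet below exercises
the API that the new docstrings describe. It assumes the jax version this patch
targets and a runtime with eight devices (as in the `NamedSharding` docstring
example, e.g. a TPU v3-8 or eight host CPU devices), and it uses only names that
appear in the diff above (`Mesh`, `PartitionSpec`, `jax.sharding.NamedSharding`,
`jax.sharding.SingleDeviceSharding`, `shard_shape`, `devices_indices_map`,
`device_set`, `is_fully_addressable`).

import numpy as np
import jax
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P

# A 2x4 logical mesh with axes named 'x' and 'y'; assumes exactly 8 devices
# are visible (adjust the reshape to match your device count).
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))

# Shard the first array dimension across 'x' and the second across 'y'.
named_sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))

global_shape = (8, 16)
# Per-device shard shape, derived from the global shape and the sharding:
# (8 // 2, 16 // 4) == (4, 4).
print(named_sharding.shard_shape(global_shape))

# Mapping from every global device to the index (slice) of the global data it
# holds; in multi-process JAX this also covers non-addressable devices.
print(named_sharding.devices_indices_map(global_shape))

# True here because a single-process run can address every device in device_set.
print(named_sharding.is_fully_addressable)

# A sharding that keeps the whole array on one device; its shard shape equals
# the global shape.
single_device_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
print(single_device_sharding.device_set)
print(single_device_sharding.shard_shape(global_shape))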