mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Add a docstring for maps.Mesh
This commit is contained in:
parent
b267fd4336
commit
9719cc89d3
@ -8,6 +8,6 @@ API
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
mesh
|
||||
|
||||
Mesh
|
||||
xmap
|
||||
|
@ -1832,6 +1832,49 @@ mlir.register_lowering(xla_pmap_p, _pmap_lowering)
|
||||
# ------------------- xmap -------------------
|
||||
|
||||
class Mesh(ContextDecorator):
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
|
||||
In particular, all ``axis_names`` become valid resource names inside the
|
||||
managed block and can be used e.g. in the ``axis_resources`` argument of
|
||||
:py:func:`xmap`.
|
||||
|
||||
If you are compiling in multiple threads, make sure that the
|
||||
``with Mesh`` context manager is inside the function that the threads will
|
||||
execute.
|
||||
|
||||
Args:
|
||||
devices: A NumPy ndarray object containing JAX device objects (as
|
||||
obtained e.g. from :py:func:`jax.devices`).
|
||||
axis_names: A sequence of resource axis names to be assigned to the
|
||||
dimensions of the ``devices`` argument. Its length should match the
|
||||
rank of ``devices``.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from jax.experimental.maps import Mesh
|
||||
>>> from jax.experimental.pjit import pjit
|
||||
>>> from jax.experimental import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> inp = np.arange(16).reshape((8, 2))
|
||||
>>> devices = np.array(jax.devices()).reshape(4, 2)
|
||||
...g
|
||||
>>> # Declare a 2D mesh with axes `x` and `y`.
|
||||
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
||||
>>> # Use the mesh object directly as a context manager.
|
||||
>>> with global_mesh:
|
||||
... pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
||||
|
||||
>>> # Initialize the Mesh and use the mesh as the context manager.
|
||||
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
|
||||
... pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
||||
|
||||
>>> # Also you can use it as `with ... as ...`.
|
||||
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
||||
>>> with global_mesh as m:
|
||||
... pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
||||
"""
|
||||
|
||||
devices: np.ndarray
|
||||
axis_names: Tuple[MeshAxisName, ...]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user