Document ValueError raised from mesh util functions

PiperOrigin-RevId: 563158659
jax authors 2023-09-06 11:09:46 -07:00
parent bb8d5a0121
commit a23bc36d9a

@@ -286,6 +286,10 @@ def create_device_mesh(
required when passing host local inputs to `pjit`. A ValueError will be
raised if this function can't produce a suitable mesh.
+Raises:
+ValueError: if the number of devices doesn't equal the product of
+`mesh_shape`.
Returns:
A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
into jax.sharding.Mesh with good collective performance.
@@ -327,8 +331,8 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
mesh_shape: shape of the logical mesh for the faster/inner network, ordered
by increasing network intensity, e.g. [replica, data, mdl] where mdl has
the most network communication requirements.
-dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
-in the same order as mesh_shape.
+dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in
+the same order as mesh_shape.
devices: optionally, the devices to construct a mesh for. Defaults to
jax.devices().
process_is_granule: if True, this function will treat processes as the units
@@ -336,6 +340,11 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
attributes on devices and use slices as the units. Enabling this is meant
as a fallback for platforms (e.g., GPU) that don't set slice_index.
+Raises:
+ValueError: if the number of slices to which the `devices` belong doesn't
+equal the product of `dcn_mesh_shape`, or if the number of devices
+belonging to any single slice does not equal the product of `mesh_shape`.
Returns:
A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
that can be fed into jax.sharding.Mesh for hybrid parallelism.
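
The two documented `ValueError` conditions can be illustrated with a small stand-alone sketch. It is not the JAX implementation: `check_device_mesh`, `check_hybrid_device_mesh`, `raises_value_error`, and the `slice_ids` argument (one slice identifier per device, standing in for the `slice_index` attribute mentioned in the docstring) are hypothetical names introduced only for this example, and the error messages are made up.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def check_device_mesh(mesh_shape: Sequence[int], num_devices: int) -> None:
    # Hypothetical helper, not part of JAX: mirrors the condition documented
    # for create_device_mesh (device count must equal prod(mesh_shape)).
    expected = math.prod(mesh_shape)
    if num_devices != expected:
        raise ValueError(
            f"got {num_devices} devices, but mesh_shape {tuple(mesh_shape)} "
            f"needs {expected}")


def check_hybrid_device_mesh(
    mesh_shape: Sequence[int],
    dcn_mesh_shape: Sequence[int],
    slice_ids: Sequence[Hashable],
) -> None:
    # Hypothetical helper, not part of JAX. slice_ids holds one slice
    # identifier per device (an assumed stand-in for device.slice_index).
    devices_per_slice = Counter(slice_ids)

    # Condition 1: number of slices must equal prod(dcn_mesh_shape).
    expected_slices = math.prod(dcn_mesh_shape)
    if len(devices_per_slice) != expected_slices:
        raise ValueError(
            f"got {len(devices_per_slice)} slices, but dcn_mesh_shape "
            f"{tuple(dcn_mesh_shape)} needs {expected_slices}")

    # Condition 2: every slice must hold exactly prod(mesh_shape) devices.
    expected_per_slice = math.prod(mesh_shape)
    for slice_id, count in devices_per_slice.items():
        if count != expected_per_slice:
            raise ValueError(
                f"slice {slice_id!r} has {count} devices, but mesh_shape "
                f"{tuple(mesh_shape)} needs {expected_per_slice}")


def raises_value_error(fn, *args) -> bool:
    # Small utility for the demo: True if fn(*args) raises ValueError.
    try:
        fn(*args)
    except ValueError:
        return True
    return False


if __name__ == "__main__":
    # 2 * 4 == 8 devices: accepted.
    print(raises_value_error(check_device_mesh, [2, 4], 8))
    # 6 devices cannot fill a 2 x 4 mesh: rejected.
    print(raises_value_error(check_device_mesh, [2, 4], 6))
    # Two slices of four devices each match dcn [2, 1] and mesh [1, 4].
    print(raises_value_error(
        check_hybrid_device_mesh, [1, 4], [2, 1], [0, 0, 0, 0, 1, 1, 1, 1]))
    # All eight devices in one slice: slice count mismatch.
    print(raises_value_error(
        check_hybrid_device_mesh, [1, 4], [2, 1], [0] * 8))
```

Callers of the real functions could apply the same arithmetic to their own `mesh_shape` and `dcn_mesh_shape` before constructing a mesh, so that a shape mismatch surfaces with a message specific to their setup.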