mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Document ValueError raised from mesh util functions
PiperOrigin-RevId: 563158659
This commit is contained in:
parent
bb8d5a0121
commit
a23bc36d9a
@ -286,6 +286,10 @@ def create_device_mesh(
|
||||
required when passing host local inputs to `pjit`. A ValueError will be
|
||||
raised if this function can't produce a suitable mesh.
|
||||
|
||||
Raises:
|
||||
ValueError: if the number of devices doesn't equal the product of
|
||||
`mesh_shape`.
|
||||
|
||||
Returns:
|
||||
A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
|
||||
into jax.sharding.Mesh with good collective performance.
|
||||
@ -327,8 +331,8 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
|
||||
mesh_shape: shape of the logical mesh for the faster/inner network, ordered
|
||||
by increasing network intensity, e.g. [replica, data, mdl] where mdl has
|
||||
the most network communication requirements.
|
||||
dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
|
||||
in the same order as mesh_shape.
|
||||
dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in
|
||||
the same order as mesh_shape.
|
||||
devices: optionally, the devices to construct a mesh for. Defaults to
|
||||
jax.devices().
|
||||
process_is_granule: if True, this function will treat processes as the units
|
||||
@ -336,6 +340,11 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
|
||||
attributes on devices and use slices as the units. Enabling this is meant
|
||||
as a fallback for platforms (e.g., GPU) that don't set slice_index.
|
||||
|
||||
Raises:
|
||||
ValueError: if the number of slices to which the `devices` belong doesn't
|
||||
equal the product of `dcn_mesh_shape`, or if the number of devices
|
||||
belonging to any single slice does not equal the product of `mesh_shape`.
|
||||
|
||||
Returns:
|
||||
A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
|
||||
that can be fed into jax.sharding.Mesh for hybrid parallelism.
|
||||
|
Loading…
x
Reference in New Issue
Block a user