mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mesh_utils] Add device/slice count checks
PiperOrigin-RevId: 443178279
This commit is contained in:
parent
455c9f823e
commit
38e754585f
@@ -250,6 +250,8 @@ def create_device_mesh(
|
||||
"""
|
||||
if devices is None:
|
||||
devices = jax.devices()
|
||||
if np.prod(mesh_shape) != len(devices):
|
||||
raise ValueError('Number of devices must equal the product of mesh_shape')
|
||||
device_kind = devices[-1].device_kind
|
||||
if device_kind in (_TPU_V2, _TPU_V3):
|
||||
if len(devices) == 8:
|
||||
@@ -316,6 +318,9 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
|
||||
granule_id += 1
|
||||
else:
|
||||
break
|
||||
if np.prod(dcn_mesh_shape) != len(granules):
|
||||
raise ValueError(
|
||||
'Number of slices must equal the product of dcn_mesh_shape')
|
||||
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
|
||||
for granule in granules]
|
||||
# TODO(jekbradbury): handle non-uniform DCN topologies
|
||||
|
Loading…
x
Reference in New Issue
Block a user