[mesh_utils] Add device/slice count checks

PiperOrigin-RevId: 443178279
This commit is contained in:
James Bradbury 2022-04-20 13:29:33 -07:00 committed by jax authors
parent 455c9f823e
commit 38e754585f

View File

@ -250,6 +250,8 @@ def create_device_mesh(
"""
if devices is None:
devices = jax.devices()
if np.prod(mesh_shape) != len(devices):
raise ValueError('Number of devices must equal the product of mesh_shape')
device_kind = devices[-1].device_kind
if device_kind in (_TPU_V2, _TPU_V3):
if len(devices) == 8:
@ -316,6 +318,9 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
granule_id += 1
else:
break
if np.prod(dcn_mesh_shape) != len(granules):
raise ValueError(
'Number of slices must equal the product of dcn_mesh_shape')
per_granule_meshes = [create_device_mesh(mesh_shape, granule)
for granule in granules]
# TODO(jekbradbury): handle non-uniform DCN topologies