Improve the error message.

PiperOrigin-RevId: 525138471
This commit is contained in:
Jean-Baptiste Lespiau 2023-04-18 07:15:12 -07:00 committed by jax authors
parent f8fe5d0542
commit 6ca249da78

View File

@ -187,7 +187,12 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
else:
out = np.empty(dims, dtype=object)
for coords, d in zip(device_coords, jax_devices):
assert d.core_on_chip == 0, d
if d.core_on_chip != 0:
raise AssertionError(
'Creating meshes for TPU >v3 requires one device per chip.'
f'Got device id {d.core_on_chip} for a device of kind {device_kind}'
f': {d}'
)
out[coords[0], coords[1], coords[2]] = d
return out