mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve the error message.
PiperOrigin-RevId: 525138471
This commit is contained in:
parent
f8fe5d0542
commit
6ca249da78
@ -187,7 +187,12 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
|
||||
else:
|
||||
out = np.empty(dims, dtype=object)
|
||||
for coords, d in zip(device_coords, jax_devices):
|
||||
assert d.core_on_chip == 0, d
|
||||
if d.core_on_chip != 0:
|
||||
raise AssertionError(
|
||||
'Creating meshes for TPU >v3 requires one device per chip.'
|
||||
f'Got device id {d.core_on_chip} for a device of kind {device_kind}'
|
||||
f': {d}'
|
||||
)
|
||||
out[coords[0], coords[1], coords[2]] = d
|
||||
return out
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user