Improve error message when trying to fetch value of non-addressable array.

PiperOrigin-RevId: 636642130
This commit is contained in:
Yash Katariya 2024-05-23 12:40:40 -07:00 committed by jax authors
parent ee79d7d12b
commit b527e1ec07

View File

@ -599,10 +599,13 @@ class ArrayImpl(basearray.Array):
# is_fully_addressable.
if (not self.is_fully_addressable and
not _process_has_full_value_in_mcjax(self.sharding, self.shape)):
raise RuntimeError("Fetching value for `jax.Array` that spans "
"non-addressable devices is not possible. You can use "
"`jax.experimental.multihost_utils.process_allgather` "
"for this use case.")
raise RuntimeError(
"Fetching value for `jax.Array` that spans non-addressable"
" (non process local) devices is not possible. You can use"
" `jax.experimental.multihost_utils.process_allgather` to print the"
" global array or use `.addressable_shards` method of jax.Array to"
" inspect the addressable (process local) shards."
)
for i, _ in _cached_index_calc(self.sharding, self.shape):
self._arrays[i]._copy_single_device_array_to_host_async()