mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error message when trying to fetch value of non-addressable array.
PiperOrigin-RevId: 636642130
This commit is contained in:
parent
ee79d7d12b
commit
b527e1ec07
@ -599,10 +599,13 @@ class ArrayImpl(basearray.Array):
|
||||
# is_fully_addressable.
|
||||
if (not self.is_fully_addressable and
|
||||
not _process_has_full_value_in_mcjax(self.sharding, self.shape)):
|
||||
raise RuntimeError("Fetching value for `jax.Array` that spans "
|
||||
"non-addressable devices is not possible. You can use "
|
||||
"`jax.experimental.multihost_utils.process_allgather` "
|
||||
"for this use case.")
|
||||
raise RuntimeError(
|
||||
"Fetching value for `jax.Array` that spans non-addressable"
|
||||
" (non process local) devices is not possible. You can use"
|
||||
" `jax.experimental.multihost_utils.process_allgather` to print the"
|
||||
" global array or use `.addressable_shards` method of jax.Array to"
|
||||
" inspect the addressable (process local) shards."
|
||||
)
|
||||
|
||||
for i, _ in _cached_index_calc(self.sharding, self.shape):
|
||||
self._arrays[i]._copy_single_device_array_to_host_async()
|
||||
|
Loading…
x
Reference in New Issue
Block a user