mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve the docstring for jax.Array.copy_to_host_async
.
PiperOrigin-RevId: 672666190
This commit is contained in:
parent
d6c36255e8
commit
72c095261f
@ -124,7 +124,19 @@ class Array(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def copy_to_host_async(self):
|
||||
"""Copies jax.Array to host asynchronously."""
|
||||
"""Copies an ``Array`` to the host asynchronously.
|
||||
|
||||
For arrays that live an an accelerator, such as a GPU or a TPU, JAX may
|
||||
cache the value of the array on the host. Normally this happens
|
||||
behind the scenes when the value of an on-device array is requested by the
|
||||
user, but waiting to initiate a device-to-host copy until the value is
|
||||
requested requires that JAX block the caller while waiting for the copy to
|
||||
complete.
|
||||
|
||||
``copy_to_host_async`` requests that JAX populate its on-host cache of an
|
||||
array, but does not wait for the copy to complete. This may speed up a
|
||||
future on-host access to the array's contents.
|
||||
"""
|
||||
|
||||
|
||||
Array.__module__ = "jax"
|
||||
|
Loading…
x
Reference in New Issue
Block a user