Improve the docstring for jax.Array.copy_to_host_async.

PiperOrigin-RevId: 672666190
This commit is contained in:
Peter Hawkins 2024-09-09 14:03:19 -07:00 committed by jax authors
parent d6c36255e8
commit 72c095261f

View File

@ -124,7 +124,19 @@ class Array(abc.ABC):
@abc.abstractmethod
def copy_to_host_async(self):
"""Copies jax.Array to host asynchronously."""
"""Copies an ``Array`` to the host asynchronously.
For arrays that live an an accelerator, such as a GPU or a TPU, JAX may
cache the value of the array on the host. Normally this happens
behind the scenes when the value of an on-device array is requested by the
user, but waiting to initiate a device-to-host copy until the value is
requested requires that JAX block the caller while waiting for the copy to
complete.
``copy_to_host_async`` requests that JAX populate its on-host cache of an
array, but does not wait for the copy to complete. This may speed up a
future on-host access to the array's contents.
"""
Array.__module__ = "jax"