Add jax.copy_to_host_async(tree).

A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
This commit is contained in:
Tom Hennigan 2025-02-27 01:21:39 -08:00 committed by jax authors
parent 2646b8d4ad
commit 1becb57ac9
4 changed files with 51 additions and 0 deletions

View File

@ -80,6 +80,7 @@ Just-in-time compilation (:code:`jit`)
named_call
named_scope
block_until_ready
copy_to_host_async
make_mesh
.. _jax-grad:

View File

@ -85,6 +85,7 @@ from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_caches as clear_caches
from jax._src.api import copy_to_host_async as copy_to_host_async
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.custom_derivatives import custom_gradient as custom_gradient
from jax._src.custom_derivatives import custom_jvp as custom_jvp

View File

@ -2828,6 +2828,31 @@ def block_until_ready(x):
return x
def copy_to_host_async(x):
"""
Tries to call a ``copy_to_host_async`` method on pytree leaves.
For each leaf this method will try to call the ``copy_to_host_async`` method
on the leaf. If the leaf is not a JAX array, or if the leaf does not have a
``copy_to_host_async`` method, then this method will do nothing to the leaf.
Args:
x: a pytree, usually with at least some JAX array instances at its leaves.
Returns:
A pytree with the same structure and values of the input, where the host
copy of the values of all JAX array leaves are started.
"""
for leaf in tree_leaves(x):
try:
copy_fn = leaf.copy_to_host_async
except AttributeError:
pass
else:
copy_fn()
return x
def clear_backends():
"""
Clear all backend clients so that new backend clients can be created later.

View File

@ -2581,6 +2581,30 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
self.assertEqual(pytree[3], 4)
def test_copy_to_host_async(self):
x = device_put(1.)
y = jax.copy_to_host_async(x)
# Tests mostly that copy_to_host_async() does not produce an error.
self.assertIs(y, x)
self.assertEqual(np.asarray(y), 1.)
def test_copy_to_host_async_non_array(self):
# Just tests that we don't error...
o = object()
mock_array = unittest.mock.Mock()
mock_array.copy_to_host_async.return_value = None
x = [o, 1, 2, 3, mock_array]
y = jax.copy_to_host_async(x)
self.assertIs(y, x)
self.assertEqual(y, [o, 1, 2, 3, mock_array])
mock_array.copy_to_host_async.assert_called_once()
def test_copy_to_host_async_does_not_hide_attribute_error(self):
x = unittest.mock.Mock()
x.copy_to_host_async.side_effect = AttributeError("foo")
with self.assertRaisesRegex(AttributeError, "foo"):
jax.copy_to_host_async(x)
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
def test_devicearray_weakref_friendly(self):
x = device_put(1.)