mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add jax.copy_to_host_async(tree)
.
A relatively common pattern I've observed is the following: ```python _, metrics = some_jax_function() with profiler.Trace('compute_metrics'): jax.block_until_ready(metrics) with profiler.Trace('copy_to_host'): metrics = jax.device_get(metrics) ``` We are missing an opportunity here to more eagerly begin the h2d copy of the metrics (e.g. overlap it with closing the "compute_metrics" context manager etc. The intention of `jax.copy_to_host_async(x)` is to make it simple to begin h2d transfers as early as possible. Adapting the above code: ```python _, metrics = some_jax_function() # Begin D2H copies as early as we can. jax.copy_to_host_async(metrics) with profiler.Trace('compute_metrics'): jax.block_until_ready(metrics) with profiler.Trace('copy_to_host'): metrics = jax.device_get(metrics) ``` PiperOrigin-RevId: 731626446
This commit is contained in:
parent
2646b8d4ad
commit
1becb57ac9
@ -80,6 +80,7 @@ Just-in-time compilation (:code:`jit`)
|
||||
named_call
|
||||
named_scope
|
||||
block_until_ready
|
||||
copy_to_host_async
|
||||
make_mesh
|
||||
|
||||
.. _jax-grad:
|
||||
|
@ -85,6 +85,7 @@ from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
||||
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||
from jax._src.api import clear_caches as clear_caches
|
||||
from jax._src.api import copy_to_host_async as copy_to_host_async
|
||||
from jax._src.custom_derivatives import closure_convert as closure_convert
|
||||
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
||||
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
||||
|
@ -2828,6 +2828,31 @@ def block_until_ready(x):
|
||||
|
||||
return x
|
||||
|
||||
def copy_to_host_async(x):
|
||||
"""
|
||||
Tries to call a ``copy_to_host_async`` method on pytree leaves.
|
||||
|
||||
For each leaf this method will try to call the ``copy_to_host_async`` method
|
||||
on the leaf. If the leaf is not a JAX array, or if the leaf does not have a
|
||||
``copy_to_host_async`` method, then this method will do nothing to the leaf.
|
||||
|
||||
Args:
|
||||
x: a pytree, usually with at least some JAX array instances at its leaves.
|
||||
|
||||
Returns:
|
||||
A pytree with the same structure and values of the input, where the host
|
||||
copy of the values of all JAX array leaves are started.
|
||||
"""
|
||||
for leaf in tree_leaves(x):
|
||||
try:
|
||||
copy_fn = leaf.copy_to_host_async
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
copy_fn()
|
||||
|
||||
return x
|
||||
|
||||
def clear_backends():
|
||||
"""
|
||||
Clear all backend clients so that new backend clients can be created later.
|
||||
|
@ -2581,6 +2581,30 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
|
||||
self.assertEqual(pytree[3], 4)
|
||||
|
||||
def test_copy_to_host_async(self):
|
||||
x = device_put(1.)
|
||||
y = jax.copy_to_host_async(x)
|
||||
# Tests mostly that copy_to_host_async() does not produce an error.
|
||||
self.assertIs(y, x)
|
||||
self.assertEqual(np.asarray(y), 1.)
|
||||
|
||||
def test_copy_to_host_async_non_array(self):
|
||||
# Just tests that we don't error...
|
||||
o = object()
|
||||
mock_array = unittest.mock.Mock()
|
||||
mock_array.copy_to_host_async.return_value = None
|
||||
x = [o, 1, 2, 3, mock_array]
|
||||
y = jax.copy_to_host_async(x)
|
||||
self.assertIs(y, x)
|
||||
self.assertEqual(y, [o, 1, 2, 3, mock_array])
|
||||
mock_array.copy_to_host_async.assert_called_once()
|
||||
|
||||
def test_copy_to_host_async_does_not_hide_attribute_error(self):
|
||||
x = unittest.mock.Mock()
|
||||
x.copy_to_host_async.side_effect = AttributeError("foo")
|
||||
with self.assertRaisesRegex(AttributeError, "foo"):
|
||||
jax.copy_to_host_async(x)
|
||||
|
||||
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
|
||||
def test_devicearray_weakref_friendly(self):
|
||||
x = device_put(1.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user