mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 05:06:07 +00:00
Add jax.copy_to_host_async(tree)
.
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to begin the D2H (device-to-host) copy of the metrics more eagerly (e.g. to overlap it with closing the "compute_metrics" context manager). The intention of `jax.copy_to_host_async(x)` is to make it simple to begin D2H transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()
# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
This commit is contained in:
parent
2646b8d4ad
commit
1becb57ac9
@ -80,6 +80,7 @@ Just-in-time compilation (:code:`jit`)
|
|||||||
named_call
|
named_call
|
||||||
named_scope
|
named_scope
|
||||||
block_until_ready
|
block_until_ready
|
||||||
|
copy_to_host_async
|
||||||
make_mesh
|
make_mesh
|
||||||
|
|
||||||
.. _jax-grad:
|
.. _jax-grad:
|
||||||
|
@ -85,6 +85,7 @@ from jax._src.api import block_until_ready as block_until_ready
|
|||||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
||||||
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||||
from jax._src.api import clear_caches as clear_caches
|
from jax._src.api import clear_caches as clear_caches
|
||||||
|
from jax._src.api import copy_to_host_async as copy_to_host_async
|
||||||
from jax._src.custom_derivatives import closure_convert as closure_convert
|
from jax._src.custom_derivatives import closure_convert as closure_convert
|
||||||
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
||||||
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
||||||
|
@ -2828,6 +2828,31 @@ def block_until_ready(x):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def copy_to_host_async(x):
  """Tries to call a ``copy_to_host_async`` method on pytree leaves.

  For each leaf of ``x`` this function looks up a ``copy_to_host_async``
  attribute and, if the leaf has one, calls it with no arguments. The call is
  intended to begin the device-to-host transfer early so that a later
  ``jax.device_get`` has less left to wait for. Leaves that do not have a
  ``copy_to_host_async`` attribute (for example Python scalars or arbitrary
  objects) are left untouched.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    ``x`` itself (the same object, not a copy), after ``copy_to_host_async``
    has been called on every leaf that provides it.

  Raises:
    Any exception raised by a leaf's ``copy_to_host_async`` call is propagated
    unchanged, including ``AttributeError``.
  """
  for leaf in tree_leaves(x):
    # Guard only the attribute lookup. An AttributeError raised *inside* the
    # leaf's copy_to_host_async() must reach the caller instead of being
    # mistaken for "this leaf has no such method".
    try:
      copy_fn = leaf.copy_to_host_async
    except AttributeError:
      continue
    copy_fn()

  return x
|
||||||
|
|
||||||
def clear_backends():
|
def clear_backends():
|
||||||
"""
|
"""
|
||||||
Clear all backend clients so that new backend clients can be created later.
|
Clear all backend clients so that new backend clients can be created later.
|
||||||
|
@ -2581,6 +2581,30 @@ class APITest(jtu.JaxTestCase):
|
|||||||
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
|
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
|
||||||
self.assertEqual(pytree[3], 4)
|
self.assertEqual(pytree[3], 4)
|
||||||
|
|
||||||
|
def test_copy_to_host_async(self):
|
||||||
|
x = device_put(1.)
|
||||||
|
y = jax.copy_to_host_async(x)
|
||||||
|
# Tests mostly that copy_to_host_async() does not produce an error.
|
||||||
|
self.assertIs(y, x)
|
||||||
|
self.assertEqual(np.asarray(y), 1.)
|
||||||
|
|
||||||
|
def test_copy_to_host_async_non_array(self):
|
||||||
|
# Just tests that we don't error...
|
||||||
|
o = object()
|
||||||
|
mock_array = unittest.mock.Mock()
|
||||||
|
mock_array.copy_to_host_async.return_value = None
|
||||||
|
x = [o, 1, 2, 3, mock_array]
|
||||||
|
y = jax.copy_to_host_async(x)
|
||||||
|
self.assertIs(y, x)
|
||||||
|
self.assertEqual(y, [o, 1, 2, 3, mock_array])
|
||||||
|
mock_array.copy_to_host_async.assert_called_once()
|
||||||
|
|
||||||
|
def test_copy_to_host_async_does_not_hide_attribute_error(self):
|
||||||
|
x = unittest.mock.Mock()
|
||||||
|
x.copy_to_host_async.side_effect = AttributeError("foo")
|
||||||
|
with self.assertRaisesRegex(AttributeError, "foo"):
|
||||||
|
jax.copy_to_host_async(x)
|
||||||
|
|
||||||
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
|
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
|
||||||
def test_devicearray_weakref_friendly(self):
|
def test_devicearray_weakref_friendly(self):
|
||||||
x = device_put(1.)
|
x = device_put(1.)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user