From 1becb57ac94567a21b176ccbdb4900445d28ec1d Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Thu, 27 Feb 2025 01:21:39 -0800 Subject: [PATCH] Add `jax.copy_to_host_async(tree)`. A relatively common pattern I've observed is the following: ```python _, metrics = some_jax_function() with profiler.Trace('compute_metrics'): jax.block_until_ready(metrics) with profiler.Trace('copy_to_host'): metrics = jax.device_get(metrics) ``` We are missing an opportunity here to more eagerly begin the h2d copy of the metrics (e.g. overlap it with closing the "compute_metrics" context manager etc. The intention of `jax.copy_to_host_async(x)` is to make it simple to begin h2d transfers as early as possible. Adapting the above code: ```python _, metrics = some_jax_function() # Begin D2H copies as early as we can. jax.copy_to_host_async(metrics) with profiler.Trace('compute_metrics'): jax.block_until_ready(metrics) with profiler.Trace('copy_to_host'): metrics = jax.device_get(metrics) ``` PiperOrigin-RevId: 731626446 --- docs/jax.rst | 1 + jax/__init__.py | 1 + jax/_src/api.py | 25 +++++++++++++++++++++++++ tests/api_test.py | 24 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+) diff --git a/docs/jax.rst b/docs/jax.rst index 432240a1e..98cd464cd 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -80,6 +80,7 @@ Just-in-time compilation (:code:`jit`) named_call named_scope block_until_ready + copy_to_host_async make_mesh .. _jax-grad: diff --git a/jax/__init__.py b/jax/__init__.py index 64f76e84a..950c3ed4b 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -85,6 +85,7 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies from jax._src.api import clear_caches as clear_caches +from jax._src.api import copy_to_host_async as copy_to_host_async from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient from jax._src.custom_derivatives import custom_jvp as custom_jvp diff --git a/jax/_src/api.py b/jax/_src/api.py index e9a9723d9..a5cbf8702 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2828,6 +2828,31 @@ def block_until_ready(x): return x +def copy_to_host_async(x): + """ + Tries to call a ``copy_to_host_async`` method on pytree leaves. + + For each leaf this method will try to call the ``copy_to_host_async`` method + on the leaf. If the leaf is not a JAX array, or if the leaf does not have a + ``copy_to_host_async`` method, then this method will do nothing to the leaf. + + Args: + x: a pytree, usually with at least some JAX array instances at its leaves. + + Returns: + A pytree with the same structure and values of the input, where the host + copy of the values of all JAX array leaves are started. + """ + for leaf in tree_leaves(x): + try: + copy_fn = leaf.copy_to_host_async + except AttributeError: + pass + else: + copy_fn() + + return x + def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. diff --git a/tests/api_test.py b/tests/api_test.py index f145157cd..e8fef8011 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2581,6 +2581,30 @@ class APITest(jtu.JaxTestCase): self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) self.assertEqual(pytree[3], 4) + def test_copy_to_host_async(self): + x = device_put(1.) + y = jax.copy_to_host_async(x) + # Tests mostly that copy_to_host_async() does not produce an error. + self.assertIs(y, x) + self.assertEqual(np.asarray(y), 1.) + + def test_copy_to_host_async_non_array(self): + # Just tests that we don't error... + o = object() + mock_array = unittest.mock.Mock() + mock_array.copy_to_host_async.return_value = None + x = [o, 1, 2, 3, mock_array] + y = jax.copy_to_host_async(x) + self.assertIs(y, x) + self.assertEqual(y, [o, 1, 2, 3, mock_array]) + mock_array.copy_to_host_async.assert_called_once() + + def test_copy_to_host_async_does_not_hide_attribute_error(self): + x = unittest.mock.Mock() + x.copy_to_host_async.side_effect = AttributeError("foo") + with self.assertRaisesRegex(AttributeError, "foo"): + jax.copy_to_host_async(x) + @jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads def test_devicearray_weakref_friendly(self): x = device_put(1.)