Tom Hennigan 1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to begin the D2H copy of the metrics
more eagerly (e.g. overlapping it with closing the "compute_metrics"
context manager, etc.). The intention of `jax.copy_to_host_async(x)` is to
make it simple to begin D2H transfers as early as possible. Adapting the
above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```
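
A self-contained, runnable version of the adapted pattern might look like the following. It is a sketch under stated assumptions. `train_step` is a hypothetical stand-in for `some_jax_function`. `jax.profiler.TraceAnnotation` replaces the unspecified `profiler.Trace`. `start_host_copies` is a hypothetical helper that calls `jax.copy_to_host_async` when the installed JAX provides it, and otherwise calls `copy_to_host_async()` on each `jax.Array` leaf, which we assume is roughly what the new helper does.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(x):
  # Hypothetical stand-in for `some_jax_function`: returns (state, metrics).
  loss = jnp.mean(x ** 2)
  return x * 0.5, {"loss": loss, "max": jnp.max(x)}


def start_host_copies(tree):
  # Hypothetical helper: prefer the top-level API if this JAX version has it,
  # otherwise fall back to the per-array method on every leaf that has one.
  if hasattr(jax, "copy_to_host_async"):
    jax.copy_to_host_async(tree)
  else:
    for leaf in jax.tree_util.tree_leaves(tree):
      if hasattr(leaf, "copy_to_host_async"):
        leaf.copy_to_host_async()


def fetch_metrics(x):
  _, metrics = train_step(x)

  # Request D2H copies right after dispatch, before blocking on the result.
  start_host_copies(metrics)

  with jax.profiler.TraceAnnotation("compute_metrics"):
    jax.block_until_ready(metrics)

  with jax.profiler.TraceAnnotation("copy_to_host"):
    # Returns a dict of NumPy arrays; the transfer may already be under way.
    return jax.device_get(metrics)


metrics = fetch_metrics(jnp.arange(4.0))
print(metrics)
```

Because JAX dispatches `train_step` asynchronously, the copies are requested before the computation is known to be finished; how much of the transfer actually overlaps with other host work depends on the backend.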

PiperOrigin-RevId: 731626446

To rebuild the documentation, see Update Documentation.