Implement more efficient jax.block_until_ready(x) in C++

The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes GIL to be acquired/released repeatedly.

To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.

PiperOrigin-RevId: 581302290
This commit is contained in:
Junwhan Ahn 2023-11-10 10:32:45 -08:00 committed by jax authors
parent 2aaa7559f9
commit 6cc6d09364
2 changed files with 29 additions and 1 deletions

View File

@ -2940,7 +2940,29 @@ def block_until_ready(x):
return x.block_until_ready()
except AttributeError:
return x
return tree_map(try_to_block, x)
if xla_extension_version < 214:
return tree_map(try_to_block, x)
arrays = []
for leaf in tree_leaves(x):
if isinstance(leaf, array.ArrayImpl):
arrays.append(leaf)
else:
try_to_block(leaf)
if not arrays:
# `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
# jax.Array.
pass
elif len(arrays) == 1:
# Fast path for single array.
try_to_block(arrays[0])
else:
# Optimized for multiple arrays.
xc.batched_block_until_ready(arrays)
return x
def clear_backends():

View File

@ -2409,6 +2409,12 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False)
def test_block_until_ready_numpy_arrays(self):
pytree = (np.ones(1), np.ones(2))
pytree = jax.block_until_ready(pytree)
self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False)
self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False)
def test_devicearray_weakref_friendly(self):
x = device_put(1.)
y = weakref.ref(x)