mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Reverts 55394a0914dc0583427a4ceb73dac56348911d15
PiperOrigin-RevId: 616201321
This commit is contained in:
parent
94122f8117
commit
f569031456
@ -2958,7 +2958,29 @@ def block_until_ready(x):
|
||||
return x.block_until_ready()
|
||||
except AttributeError:
|
||||
return x
|
||||
return tree_map(try_to_block, x)
|
||||
|
||||
if xla_extension_version < 246:
|
||||
return tree_map(try_to_block, x)
|
||||
|
||||
arrays = []
|
||||
for leaf in tree_leaves(x):
|
||||
if isinstance(leaf, array.ArrayImpl):
|
||||
arrays.append(leaf)
|
||||
else:
|
||||
try_to_block(leaf)
|
||||
|
||||
if not arrays:
|
||||
# `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
|
||||
# jax.Array.
|
||||
pass
|
||||
elif len(arrays) == 1:
|
||||
# Fast path for single array.
|
||||
try_to_block(arrays[0])
|
||||
else:
|
||||
# Optimized for multiple arrays.
|
||||
xc.batched_block_until_ready(arrays)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def clear_backends():
|
||||
|
@ -2374,6 +2374,20 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
|
||||
self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False)
|
||||
|
||||
def test_block_until_ready_numpy_arrays(self):
|
||||
pytree = (np.ones(1), np.ones(2))
|
||||
pytree = jax.block_until_ready(pytree)
|
||||
self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False)
|
||||
self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False)
|
||||
|
||||
def test_block_until_ready_mixed(self):
|
||||
pytree = (device_put(1.), device_put(2.), np.ones(3), 4)
|
||||
pytree = jax.block_until_ready(pytree)
|
||||
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
|
||||
self.assertAllClose(pytree[1], jnp.array(2.), check_dtypes=False)
|
||||
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
|
||||
self.assertEqual(pytree[3], 4)
|
||||
|
||||
def test_devicearray_weakref_friendly(self):
|
||||
x = device_put(1.)
|
||||
y = weakref.ref(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user