mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
commit
dce27bfca6
@ -1109,8 +1109,10 @@ def _device_get(x):
|
||||
|
||||
def device_get(x):
|
||||
for y in tree_leaves(x):
|
||||
if not isinstance(y, core.Tracer):
|
||||
try:
|
||||
y.copy_to_host_async()
|
||||
except AttributeError:
|
||||
pass
|
||||
return tree_map(_device_get, x)
|
||||
|
||||
|
||||
|
@ -184,7 +184,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
try:
|
||||
tree_multimap(f, xs, ys_bad)
|
||||
assert False
|
||||
except TypeError:
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
def test_print_jaxpr_compound(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user