mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Move failing CPPJitTest test case to PythonJitTest (#4268)
This commit is contained in:
parent
6dc161cf27
commit
ee9dccf39e
@ -238,16 +238,6 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
self.assertDeleted(c)
|
||||
self.assertDeleted(d)
|
||||
|
||||
def test_jit_nested_donate_ignored(self):
|
||||
jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
|
||||
a = jax.device_put(jnp.array(1))
|
||||
|
||||
# NOTE(mattjj): stopped raising error here and instead just ignored
|
||||
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
# jit_fun(a)
|
||||
|
||||
jit_fun(a) # doesn't crash
|
||||
|
||||
def test_jnp_array_copy(self):
|
||||
# https://github.com/google/jax/issues/3412
|
||||
|
||||
@ -396,6 +386,16 @@ class PythonJitTest(CPPJitTest):
|
||||
x = device_put(data, device=device)
|
||||
np.testing.assert_array_equal(-data, f(x))
|
||||
|
||||
def test_jit_nested_donate_ignored(self):
|
||||
jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
|
||||
a = jax.device_put(jnp.array(1))
|
||||
|
||||
# NOTE(mattjj): stopped raising error here and instead just ignored
|
||||
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
|
||||
# jit_fun(a)
|
||||
|
||||
jit_fun(a) # doesn't crash
|
||||
|
||||
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user