Move failing CPPJitTest test case to PythonJitTest (#4268)

This commit is contained in:
Skye Wanderman-Milne 2020-09-11 12:12:34 -07:00 committed by GitHub
parent 6dc161cf27
commit ee9dccf39e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -238,16 +238,6 @@ class CPPJitTest(jtu.JaxTestCase):
self.assertDeleted(c)
self.assertDeleted(d)
def test_jit_nested_donate_ignored(self):
jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
a = jax.device_put(jnp.array(1))
# NOTE(mattjj): stopped raising error here and instead just ignored
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
# jit_fun(a)
jit_fun(a) # doesn't crash
def test_jnp_array_copy(self):
# https://github.com/google/jax/issues/3412
@ -396,6 +386,16 @@ class PythonJitTest(CPPJitTest):
x = device_put(data, device=device)
np.testing.assert_array_equal(-data, f(x))
def test_jit_nested_donate_ignored(self):
jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
a = jax.device_put(jnp.array(1))
# NOTE(mattjj): stopped raising error here and instead just ignored
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
# jit_fun(a)
jit_fun(a) # doesn't crash
class APITest(jtu.JaxTestCase):