skip custom_jvp/vjp tests which dont work with initial-style staging

These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
This commit is contained in:
Matthew Johnson 2023-02-01 20:17:28 -08:00
parent 4d56def91f
commit cd615b6be8

View File

@ -6511,6 +6511,16 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_jit_tracer(self):
# This test would pass with "final-style" JIT tracing, but that was
# misleading: it doesn't work with "initial-style" staging, i.e. control
# flow primitives like jax.lax.scan or even pjit. The behavior isn't very
# useful either: instead of using nondiff_argnums here, a user can just pass
# such inputs as ordinary arguments, and ignore the corresponding tangents.
# Then nondiff_argnums can be reserved for (1) non jaxtype data (like a
# string- or callable-valued argument which parameterizes the function or
# rule) or (2) static data (e.g. integers which parameterize shapes).
raise unittest.SkipTest("behavior no longer supported")
@partial(api.custom_jvp, nondiff_argnums=(0,))
def f(x, y):
return x * y
@ -6527,6 +6537,22 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = (6., 5.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_vmap_tracer(self):
@partial(api.custom_jvp, nondiff_argnums=(0,))
def f(x, y):
return x * y
def f_jvp(x, primals, tangents):
(y,), (t_y,) = primals, tangents
return f(x, y), 5 * t_y
f.defjvp(f_jvp)
g = jax.vmap(f)
ans = api.jvp(lambda y: g(jnp.array([2.]), y),
(jnp.array([3.]),), (jnp.array([1.]),))
expected = (jnp.array([6.]), jnp.array([5.]))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_hiding_jvp_tracer(self):
def f(x):
@partial(api.custom_jvp, nondiff_argnums=(0,))
@ -7526,7 +7552,10 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = (2., jnp.cos(1.))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracer(self):
def test_closed_over_jit_tracer(self):
# See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer.
raise unittest.SkipTest("behavior no longer supported")
# This test is similar to test_nondiff_arg_tracer except it uses lexical
# closure rather than the nondiff_argnums mechanism. We decided to disallow
# tracers in nondiff_argnums to greatly simplify bookkeeping while still
@ -7554,7 +7583,7 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracer2(self):
def test_closed_over_vmap_tracer(self):
def outer(x):
@api.custom_vjp
def f(y):