mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
skip custom_jvp/vjp tests which dont work with initial-style staging
These tests, involving nondiff_argnums and/or closing over tracers, happen to work with final-style JIT but not our initial-style primitives. We shouldn't support this behavior anyway; there are good alternatives.
This commit is contained in:
parent
4d56def91f
commit
cd615b6be8
@ -6511,6 +6511,16 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg_jit_tracer(self):
|
||||
# This test would pass with "final-style" JIT tracing, but that was
|
||||
# misleading: it doesn't work with "initial-style" staging, i.e. control
|
||||
# flow primitives like jax.lax.scan or even pjit. The behavior isn't very
|
||||
# useful either: instead of using nondiff_argnums here, a user can just pass
|
||||
# such inputs as ordinary arguments, and ignore the corresponding tangents.
|
||||
# Then nondiff_argnums can be reserved for (1) non jaxtype data (like a
|
||||
# string- or callable-valued argument which parameterizes the function or
|
||||
# rule) or (2) static data (e.g. integers which parameterize shapes).
|
||||
raise unittest.SkipTest("behavior no longer supported")
|
||||
|
||||
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
||||
def f(x, y):
|
||||
return x * y
|
||||
@ -6527,6 +6537,22 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
expected = (6., 5.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg_vmap_tracer(self):
|
||||
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
||||
def f(x, y):
|
||||
return x * y
|
||||
def f_jvp(x, primals, tangents):
|
||||
(y,), (t_y,) = primals, tangents
|
||||
return f(x, y), 5 * t_y
|
||||
f.defjvp(f_jvp)
|
||||
|
||||
g = jax.vmap(f)
|
||||
|
||||
ans = api.jvp(lambda y: g(jnp.array([2.]), y),
|
||||
(jnp.array([3.]),), (jnp.array([1.]),))
|
||||
expected = (jnp.array([6.]), jnp.array([5.]))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg_hiding_jvp_tracer(self):
|
||||
def f(x):
|
||||
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
||||
@ -7526,7 +7552,10 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
expected = (2., jnp.cos(1.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_closed_over_tracer(self):
|
||||
def test_closed_over_jit_tracer(self):
|
||||
# See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer.
|
||||
raise unittest.SkipTest("behavior no longer supported")
|
||||
|
||||
# This test is similar to test_nondiff_arg_tracer except it uses lexical
|
||||
# closure rather than the nondiff_argnums mechanism. We decided to disallow
|
||||
# tracers in nondiff_argnums to greatly simplify bookkeeping while still
|
||||
@ -7554,7 +7583,7 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
expected = jnp.cos(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_closed_over_tracer2(self):
|
||||
def test_closed_over_vmap_tracer(self):
|
||||
def outer(x):
|
||||
@api.custom_vjp
|
||||
def f(y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user