Fix error when running a jvp of a jit of a custom_vjp.

Error before:
NotImplementedError: XLA translation rule for primitive 'custom_lin' not found

Error after:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
This commit is contained in:
Lena Martens 2021-03-18 20:08:33 +00:00 committed by lenamartens
parent c229574cff
commit 276645dd59
2 changed files with 5 additions and 0 deletions

View File

@ -671,6 +671,7 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
@config.register_omnistaging_disabler

View File

@ -3803,6 +3803,10 @@ class CustomVJPTest(jtu.JaxTestCase):
TypeError,
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))
self.assertRaisesRegex(
TypeError,
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
lambda: api.jvp(jit(f), (3.,), (1.,)))
def test_kwargs(self):
# from https://github.com/google/jax/issues/1938