mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix error when running a jvp of a jit of a custom_vjp.
Error before: NotImplementedError: XLA translation rule for primitive 'custom_lin' not found Error after: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
This commit is contained in:
parent
c229574cff
commit
276645dd59
@ -671,6 +671,7 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
|
||||
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
|
||||
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
|
@ -3803,6 +3803,10 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
TypeError,
|
||||
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
|
||||
lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
|
||||
lambda: api.jvp(jit(f), (3.,), (1.,)))
|
||||
|
||||
def test_kwargs(self):
|
||||
# from https://github.com/google/jax/issues/1938
|
||||
|
Loading…
x
Reference in New Issue
Block a user