mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
4b1bb18909
commit
26c6c3a457
@ -590,3 +590,5 @@ batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vm
|
||||
|
||||
xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
|
||||
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
|
||||
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
@ -212,6 +212,15 @@ class ODETest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_forward_mode_error(self):
|
||||
# https://github.com/google/jax/issues/3558
|
||||
|
||||
def f(k):
|
||||
return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum()
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "can't apply forward-mode.*"):
|
||||
jax.jacfwd(f)(3.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user