fix error when doing forward-mode of odeint (#3566)

fixes #3558
This commit is contained in:
Matthew Johnson 2020-06-25 20:57:34 -07:00 committed by GitHub
parent 4b1bb18909
commit 26c6c3a457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 0 deletions

View File

@ -590,3 +590,5 @@ batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vm
xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp

View File

@ -212,6 +212,15 @@ class ODETest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
def test_forward_mode_error(self):
# https://github.com/google/jax/issues/3558
def f(k):
return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum()
with self.assertRaisesRegex(TypeError, "can't apply forward-mode.*"):
jax.jacfwd(f)(3.)
if __name__ == '__main__':
absltest.main()