add add_any to jet rules table

fixes #5217
This commit is contained in:
Matthew Johnson 2020-12-17 12:10:04 -08:00
parent 943c7794f9
commit d7b5e3b5d4
2 changed files with 10 additions and 0 deletions

View File

@ -229,6 +229,7 @@ deflinear(lax.complex_p)
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.add_p)
deflinear(ad_util.add_jaxvals_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)

View File

@ -19,6 +19,7 @@ from absl.testing import absltest
import numpy as np
import unittest
import jax
from jax import test_util as jtu
import jax.numpy as jnp
import jax.scipy.special
@ -379,6 +380,14 @@ class JetTest(jtu.JaxTestCase):
assert g_out_primals == f_out_primals
assert g_out_series == f_out_series
def test_add_any(self):
# https://github.com/google/jax/issues/5217
f = lambda x, eps: x * eps + eps + x
def g(eps):
x = jnp.array(1.)
return jax.grad(f)(x, eps)
jet(g, (1.,), ([1.],)) # doesn't crash
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())